Ninjasharp commited on
Commit
cbdb56b
·
verified ·
1 Parent(s): 194f27f

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +278 -0
  2. icon.png +0 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import streamlit as st
4
+ from typing import Dict, Optional
5
+ from groq import Groq
6
+ import cairosvg
7
+ import re
8
+
9
+ # --------------------------------------------------------------------
10
+ # Streamlit page configuration
11
+ # --------------------------------------------------------------------
12
+ st.set_page_config(
13
+ layout="wide",
14
+ page_title="AI Mind Map Generator",
15
+ initial_sidebar_state="expanded"
16
+ )
17
+
18
+ # --------------------------------------------------------------------
19
+ # Helper Functions for Mermaid Parsing and SVG Generation
20
+ # --------------------------------------------------------------------
21
+ def parse_mermaid_to_svg(mermaid_code, layout="flowchart"):
22
+ """
23
+ Parses Mermaid code to extract nodes and edges and generates SVG elements based on the chosen layout.
24
+
25
+ :param mermaid_code: Mermaid graph syntax as a string.
26
+ :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
27
+ :return: SVG content as a string.
28
+ """
29
+ nodes = {}
30
+ edges = []
31
+
32
+ # Extract nodes
33
+ node_pattern = re.compile(r'(\w+)\[(.*?)\]')
34
+ for match in node_pattern.finditer(mermaid_code):
35
+ node_id, label = match.groups()
36
+ nodes[node_id] = label
37
+
38
+ # Extract edges
39
+ edge_pattern = re.compile(r'(\w+)\s*-->\s*(\w+)')
40
+ for match in edge_pattern.finditer(mermaid_code):
41
+ source, target = match.groups()
42
+ edges.append((source, target))
43
+
44
+ # Initialize SVG content
45
+ svg_content = '''
46
+ <svg viewBox="0 0 1200 800" xmlns="http://www.w3.org/2000/svg">
47
+ <!-- Background -->
48
+ <rect width="1200" height="800" fill="#ffffff"/>
49
+ '''
50
+
51
+ # Layout logic
52
+ if layout == "flowchart":
53
+ columns = 4 # Default columns for flowchart
54
+ spacing_x = 250
55
+ spacing_y = 150
56
+ start_x = 100
57
+ start_y = 100
58
+
59
+ node_positions = {}
60
+
61
+ for i, (node_id, label) in enumerate(nodes.items()):
62
+ x = start_x + (i % columns) * spacing_x
63
+ y = start_y + (i // columns) * spacing_y
64
+ node_positions[node_id] = (x, y)
65
+
66
+ svg_content += f'''
67
+ <g>
68
+ <rect x="{x - 75}" y="{y - 25}" width="150" height="50" fill="#4CAF50" rx="10" ry="10"/>
69
+ <text x="{x}" y="{y + 5}" text-anchor="middle" fill="white" font-family="Arial" font-size="14">{label}</text>
70
+ </g>
71
+ '''
72
+
73
+ # Draw edges
74
+ for source, target in edges:
75
+ if source in node_positions and target in node_positions:
76
+ x1, y1 = node_positions[source]
77
+ x2, y2 = node_positions[target]
78
+ svg_content += f'''
79
+ <line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="#000000" stroke-width="2" marker-end="url(#arrowhead)"/>
80
+ '''
81
+ elif layout == "wireframe":
82
+ center_x, center_y = 600, 400
83
+ radius = max(150, min(300, 250 + 10 * len(nodes)))
84
+ angle_step = 2 * math.pi / max(1, len(nodes)) # Avoid division by zero
85
+
86
+ node_positions = {}
87
+ for i, (node_id, label) in enumerate(nodes.items()):
88
+ angle = i * angle_step
89
+ x = center_x + radius * math.cos(angle)
90
+ y = center_y + radius * math.sin(angle)
91
+ node_positions[node_id] = (x, y)
92
+
93
+ svg_content += f'''
94
+ <g>
95
+ <rect x="{x - 75}" y="{y - 25}" width="150" height="50" fill="#f39c12" rx="10" ry="10"/>
96
+ <text x="{x}" y="{y + 5}" text-anchor="middle" fill="black" font-family="Arial" font-size="14">{label}</text>
97
+ </g>
98
+ '''
99
+
100
+ # Draw edges
101
+ for source, target in edges:
102
+ if source in node_positions and target in node_positions:
103
+ x1, y1 = node_positions[source]
104
+ x2, y2 = node_positions[target]
105
+ svg_content += f'''
106
+ <line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="#d35400" stroke-width="2" marker-end="url(#arrowhead)"/>
107
+ '''
108
+
109
+ # Add arrowhead for edges
110
+ svg_content += '''
111
+ <defs>
112
+ <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="10" refY="3.5" orient="auto">
113
+ <polygon points="0 0, 10 3.5, 0 7" fill="black"/>
114
+ </marker>
115
+ </defs>
116
+ '''
117
+
118
+ svg_content += '</svg>'
119
+ return svg_content
120
+
121
+ def generate_and_save_final_image(mermaid_code, layout="flowchart"):
122
+ """
123
+ Generates the final image based on Mermaid code and saves it as a PNG file.
124
+
125
+ :param mermaid_code: Mermaid graph syntax as a string.
126
+ :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
127
+ :return: Tuple (success, output_file or error message).
128
+ """
129
+ try:
130
+ # Parse Mermaid to SVG
131
+ svg_content = parse_mermaid_to_svg(mermaid_code, layout)
132
+
133
+ # Ensure the output directory exists
134
+ output_directory = "images"
135
+ os.makedirs(output_directory, exist_ok=True)
136
+
137
+ # Define output file path
138
+ output_file = os.path.join(output_directory, f"final_image_{layout}.png")
139
+
140
+ # Convert SVG to PNG and save
141
+ cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=output_file)
142
+ return True, output_file
143
+ except Exception as e:
144
+ return False, str(e)
145
+
146
+ # --------------------------------------------------------------------
147
+ # Supported Models
148
+ # --------------------------------------------------------------------
149
+ SUPPORTED_MODELS: Dict[str, str] = {
150
+ "Llama 3 8B": "llama3-8b-8192",
151
+ "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
152
+ "Llama 3 70B": "llama3-70b-8192",
153
+ "Mixtral 8x7B": "mixtral-8x7b-32768",
154
+ "Gemma 2 9B": "gemma2-9b-it",
155
+ "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
156
+ "Llama 3.2 11B Text (Preview)": "llama-3.2-11b-text-preview",
157
+ "Llama 3.1 8B Instant (Text-Only Workloads)": "llama-3.1-8b-instant",
158
+ "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
159
+ "Llama 3.1 70B Versatile": "llama-3.1-70b-versatile",
160
+ "Llama 3.1 8B Instant": "llama-3.1-8b-instant",
161
+ "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
162
+ "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
163
+ "Llama 3.2 3B (Preview)": "llama-3.2-3b-preview",
164
+ "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
165
+ "Llama 3.3 70B SpecDec": "llama-3.3-70b-specdec",
166
+ "Llama 3.3 70B Versatile": "llama-3.3-70b-versatile",
167
+ }
168
+
169
+ MAX_TOKENS: int = 1500
170
+
171
+ # --------------------------------------------------------------------
172
+ # Initialize Groq client with API key
173
+ # --------------------------------------------------------------------
174
+ @st.cache_resource
175
+ def get_groq_client() -> Optional[Groq]:
176
+ groq_api_key = os.getenv("GROQ_API_KEY")
177
+ if not groq_api_key:
178
+ st.error("GROQ_API_KEY not found in environment variables. Please set it and restart the app.")
179
+ return None
180
+ return Groq(api_key=groq_api_key)
181
+
182
+ client = get_groq_client()
183
+
184
+ # --------------------------------------------------------------------
185
+ # SIDEBAR
186
+ # --------------------------------------------------------------------
187
+ st.sidebar.image("icon.png", width=300)
188
+ st.sidebar.title("Model Configuration")
189
+
190
+ selected_model = st.sidebar.selectbox("Choose an AI Model", list(SUPPORTED_MODELS.keys()))
191
+
192
+ st.sidebar.subheader("Temperature")
193
+ temperature = st.sidebar.slider(
194
+ "Set temperature for generation variability:",
195
+ min_value=0.0,
196
+ max_value=1.0,
197
+ value=0.7
198
+ )
199
+
200
+ # Add layout selection
201
+ st.sidebar.subheader("Layout Configuration")
202
+ layout = st.sidebar.radio(
203
+ "Select the layout for the mind map:",
204
+ options=["flowchart", "wireframe"]
205
+ )
206
+
207
+ # --------------------------------------------------------------------
208
+ # MAIN CONTENT
209
+ # --------------------------------------------------------------------
210
+ st.title("AI Mind Map Generator")
211
+ st.markdown(
212
+ """
213
+ Enter your concepts or a short description below, then click **Generate Mind Map**.
214
+ The Groq LLM will produce Mermaid diagram code, which we'll display below.
215
+ """
216
+ )
217
+
218
+ # Text area for user input
219
+ mind_map_prompt = st.text_area(
220
+ "Describe your mind map focus:",
221
+ placeholder="e.g. 'Attention and Intention in personal development'"
222
+ )
223
+
224
+ if st.button("Generate Mind Map"):
225
+ if not mind_map_prompt.strip():
226
+ st.warning("Please provide a description or concept for the mind map.")
227
+ elif client:
228
+ with st.spinner("Generating your mind map..."):
229
+ prompt = f"""
230
+ You are an AI that generates a Mind Map in Mermaid format.
231
+ The user wants a mind map about: {mind_map_prompt}.
232
+ Please output ONLY the Mermaid diagram, nothing else.
233
+ """
234
+
235
+ try:
236
+ response = client.chat.completions.create(
237
+ model=SUPPORTED_MODELS[selected_model],
238
+ messages=[
239
+ {"role": "system", "content": "You are an AI that generates mind maps in Mermaid code."},
240
+ {"role": "user", "content": prompt},
241
+ ],
242
+ temperature=temperature,
243
+ max_tokens=MAX_TOKENS,
244
+ )
245
+
246
+ mermaid_code = response.choices[0].message.content.strip()
247
+
248
+ st.subheader("Generated Mind Map")
249
+ st.markdown(
250
+ f"""
251
+ ```mermaid
252
+ {mermaid_code}
253
+ ```
254
+ """,
255
+ unsafe_allow_html=True
256
+ )
257
+
258
+ st.download_button(
259
+ label="Download Mermaid Code",
260
+ data=mermaid_code,
261
+ file_name="mind_map_mermaid.txt",
262
+ mime="text/plain"
263
+ )
264
+
265
+ # Generate and display the final image based on layout
266
+ success, result = generate_and_save_final_image(mermaid_code, layout)
267
+ if success:
268
+ st.image(result, caption=f"Generated Mind Map ({layout.capitalize()} Layout)", use_column_width=True)
269
+ else:
270
+ st.error(f"Failed to generate image: {result}")
271
+
272
+ except Exception as e:
273
+ st.error(f"Error generating mind map: {e}")
274
+ else:
275
+ st.error("Groq client not initialized. Make sure you have set your GROQ_API_KEY environment variable.")
276
+
277
+ st.info("Built by dw — This app uses Groq LLM to generate Mermaid-based mind maps.")
278
+
icon.png ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ groq
3
+ cairosvg