Ninjasharp commited on
Commit
31fe59f
·
verified ·
1 Parent(s): 6b41db4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +278 -278
app.py CHANGED
@@ -1,278 +1,278 @@
1
- import os
2
- import math
3
- import streamlit as st
4
- from typing import Dict, Optional
5
- from groq import Groq
6
- import cairosvg
7
- import re
8
-
9
- # --------------------------------------------------------------------
10
- # Streamlit page configuration
11
- # --------------------------------------------------------------------
12
- st.set_page_config(
13
- layout="wide",
14
- page_title="AI Mind Map Generator",
15
- initial_sidebar_state="expanded"
16
- )
17
-
18
- # --------------------------------------------------------------------
19
- # Helper Functions for Mermaid Parsing and SVG Generation
20
- # --------------------------------------------------------------------
21
- def parse_mermaid_to_svg(mermaid_code, layout="flowchart"):
22
- """
23
- Parses Mermaid code to extract nodes and edges and generates SVG elements based on the chosen layout.
24
-
25
- :param mermaid_code: Mermaid graph syntax as a string.
26
- :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
27
- :return: SVG content as a string.
28
- """
29
- nodes = {}
30
- edges = []
31
-
32
- # Extract nodes
33
- node_pattern = re.compile(r'(\w+)\[(.*?)\]')
34
- for match in node_pattern.finditer(mermaid_code):
35
- node_id, label = match.groups()
36
- nodes[node_id] = label
37
-
38
- # Extract edges
39
- edge_pattern = re.compile(r'(\w+)\s*-->\s*(\w+)')
40
- for match in edge_pattern.finditer(mermaid_code):
41
- source, target = match.groups()
42
- edges.append((source, target))
43
-
44
- # Initialize SVG content
45
- svg_content = '''
46
- <svg viewBox="0 0 1200 800" xmlns="http://www.w3.org/2000/svg">
47
- <!-- Background -->
48
- <rect width="1200" height="800" fill="#ffffff"/>
49
- '''
50
-
51
- # Layout logic
52
- if layout == "flowchart":
53
- columns = 4 # Default columns for flowchart
54
- spacing_x = 250
55
- spacing_y = 150
56
- start_x = 100
57
- start_y = 100
58
-
59
- node_positions = {}
60
-
61
- for i, (node_id, label) in enumerate(nodes.items()):
62
- x = start_x + (i % columns) * spacing_x
63
- y = start_y + (i // columns) * spacing_y
64
- node_positions[node_id] = (x, y)
65
-
66
- svg_content += f'''
67
- <g>
68
- <rect x="{x - 75}" y="{y - 25}" width="150" height="50" fill="#4CAF50" rx="10" ry="10"/>
69
- <text x="{x}" y="{y + 5}" text-anchor="middle" fill="white" font-family="Arial" font-size="14">{label}</text>
70
- </g>
71
- '''
72
-
73
- # Draw edges
74
- for source, target in edges:
75
- if source in node_positions and target in node_positions:
76
- x1, y1 = node_positions[source]
77
- x2, y2 = node_positions[target]
78
- svg_content += f'''
79
- <line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="#000000" stroke-width="2" marker-end="url(#arrowhead)"/>
80
- '''
81
- elif layout == "wireframe":
82
- center_x, center_y = 600, 400
83
- radius = max(150, min(300, 250 + 10 * len(nodes)))
84
- angle_step = 2 * math.pi / max(1, len(nodes)) # Avoid division by zero
85
-
86
- node_positions = {}
87
- for i, (node_id, label) in enumerate(nodes.items()):
88
- angle = i * angle_step
89
- x = center_x + radius * math.cos(angle)
90
- y = center_y + radius * math.sin(angle)
91
- node_positions[node_id] = (x, y)
92
-
93
- svg_content += f'''
94
- <g>
95
- <rect x="{x - 75}" y="{y - 25}" width="150" height="50" fill="#f39c12" rx="10" ry="10"/>
96
- <text x="{x}" y="{y + 5}" text-anchor="middle" fill="black" font-family="Arial" font-size="14">{label}</text>
97
- </g>
98
- '''
99
-
100
- # Draw edges
101
- for source, target in edges:
102
- if source in node_positions and target in node_positions:
103
- x1, y1 = node_positions[source]
104
- x2, y2 = node_positions[target]
105
- svg_content += f'''
106
- <line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="#d35400" stroke-width="2" marker-end="url(#arrowhead)"/>
107
- '''
108
-
109
- # Add arrowhead for edges
110
- svg_content += '''
111
- <defs>
112
- <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="10" refY="3.5" orient="auto">
113
- <polygon points="0 0, 10 3.5, 0 7" fill="black"/>
114
- </marker>
115
- </defs>
116
- '''
117
-
118
- svg_content += '</svg>'
119
- return svg_content
120
-
121
- def generate_and_save_final_image(mermaid_code, layout="flowchart"):
122
- """
123
- Generates the final image based on Mermaid code and saves it as a PNG file.
124
-
125
- :param mermaid_code: Mermaid graph syntax as a string.
126
- :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
127
- :return: Tuple (success, output_file or error message).
128
- """
129
- try:
130
- # Parse Mermaid to SVG
131
- svg_content = parse_mermaid_to_svg(mermaid_code, layout)
132
-
133
- # Ensure the output directory exists
134
- output_directory = "images"
135
- os.makedirs(output_directory, exist_ok=True)
136
-
137
- # Define output file path
138
- output_file = os.path.join(output_directory, f"final_image_{layout}.png")
139
-
140
- # Convert SVG to PNG and save
141
- cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=output_file)
142
- return True, output_file
143
- except Exception as e:
144
- return False, str(e)
145
-
146
- # --------------------------------------------------------------------
147
- # Supported Models
148
- # --------------------------------------------------------------------
149
- SUPPORTED_MODELS: Dict[str, str] = {
150
- "Llama 3 8B": "llama3-8b-8192",
151
- "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
152
- "Llama 3 70B": "llama3-70b-8192",
153
- "Mixtral 8x7B": "mixtral-8x7b-32768",
154
- "Gemma 2 9B": "gemma2-9b-it",
155
- "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
156
- "Llama 3.2 11B Text (Preview)": "llama-3.2-11b-text-preview",
157
- "Llama 3.1 8B Instant (Text-Only Workloads)": "llama-3.1-8b-instant",
158
- "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
159
- "Llama 3.1 70B Versatile": "llama-3.1-70b-versatile",
160
- "Llama 3.1 8B Instant": "llama-3.1-8b-instant",
161
- "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
162
- "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
163
- "Llama 3.2 3B (Preview)": "llama-3.2-3b-preview",
164
- "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
165
- "Llama 3.3 70B SpecDec": "llama-3.3-70b-specdec",
166
- "Llama 3.3 70B Versatile": "llama-3.3-70b-versatile",
167
- }
168
-
169
- MAX_TOKENS: int = 1500
170
-
171
- # --------------------------------------------------------------------
172
- # Initialize Groq client with API key
173
- # --------------------------------------------------------------------
174
- @st.cache_resource
175
- def get_groq_client() -> Optional[Groq]:
176
- groq_api_key = os.getenv("GROQ_API_KEY")
177
- if not groq_api_key:
178
- st.error("GROQ_API_KEY not found in environment variables. Please set it and restart the app.")
179
- return None
180
- return Groq(api_key=groq_api_key)
181
-
182
- client = get_groq_client()
183
-
184
- # --------------------------------------------------------------------
185
- # SIDEBAR
186
- # --------------------------------------------------------------------
187
- st.sidebar.image("icon.png", width=300)
188
- st.sidebar.title("Model Configuration")
189
-
190
- selected_model = st.sidebar.selectbox("Choose an AI Model", list(SUPPORTED_MODELS.keys()))
191
-
192
- st.sidebar.subheader("Temperature")
193
- temperature = st.sidebar.slider(
194
- "Set temperature for generation variability:",
195
- min_value=0.0,
196
- max_value=1.0,
197
- value=0.7
198
- )
199
-
200
- # Add layout selection
201
- st.sidebar.subheader("Layout Configuration")
202
- layout = st.sidebar.radio(
203
- "Select the layout for the mind map:",
204
- options=["flowchart", "wireframe"]
205
- )
206
-
207
- # --------------------------------------------------------------------
208
- # MAIN CONTENT
209
- # --------------------------------------------------------------------
210
- st.title("AI Mind Map Generator")
211
- st.markdown(
212
- """
213
- Enter your concepts or a short description below, then click **Generate Mind Map**.
214
- The Groq LLM will produce Mermaid diagram code, which we'll display below.
215
- """
216
- )
217
-
218
- # Text area for user input
219
- mind_map_prompt = st.text_area(
220
- "Describe your mind map focus:",
221
- placeholder="e.g. 'Attention and Intention in personal development'"
222
- )
223
-
224
- if st.button("Generate Mind Map"):
225
- if not mind_map_prompt.strip():
226
- st.warning("Please provide a description or concept for the mind map.")
227
- elif client:
228
- with st.spinner("Generating your mind map..."):
229
- prompt = f"""
230
- You are an AI that generates a Mind Map in Mermaid format.
231
- The user wants a mind map about: {mind_map_prompt}.
232
- Please output ONLY the Mermaid diagram, nothing else.
233
- """
234
-
235
- try:
236
- response = client.chat.completions.create(
237
- model=SUPPORTED_MODELS[selected_model],
238
- messages=[
239
- {"role": "system", "content": "You are an AI that generates mind maps in Mermaid code."},
240
- {"role": "user", "content": prompt},
241
- ],
242
- temperature=temperature,
243
- max_tokens=MAX_TOKENS,
244
- )
245
-
246
- mermaid_code = response.choices[0].message.content.strip()
247
-
248
- st.subheader("Generated Mind Map")
249
- st.markdown(
250
- f"""
251
- ```mermaid
252
- {mermaid_code}
253
- ```
254
- """,
255
- unsafe_allow_html=True
256
- )
257
-
258
- st.download_button(
259
- label="Download Mermaid Code",
260
- data=mermaid_code,
261
- file_name="mind_map_mermaid.txt",
262
- mime="text/plain"
263
- )
264
-
265
- # Generate and display the final image based on layout
266
- success, result = generate_and_save_final_image(mermaid_code, layout)
267
- if success:
268
- st.image(result, caption=f"Generated Mind Map ({layout.capitalize()} Layout)", use_column_width=True)
269
- else:
270
- st.error(f"Failed to generate image: {result}")
271
-
272
- except Exception as e:
273
- st.error(f"Error generating mind map: {e}")
274
- else:
275
- st.error("Groq client not initialized. Make sure you have set your GROQ_API_KEY environment variable.")
276
-
277
- st.info("Built by dw — This app uses Groq LLM to generate Mermaid-based mind maps.")
278
-
 
1
+ import os
2
+ import math
3
+ import streamlit as st
4
+ from typing import Dict, Optional
5
+ from groq import Groq
6
+ import cairosvg
7
+ import re
8
+
9
+ # --------------------------------------------------------------------
10
+ # Streamlit page configuration
11
+ # --------------------------------------------------------------------
12
+ st.set_page_config(
13
+ layout="wide",
14
+ page_title="AI Mind Map Generator",
15
+ initial_sidebar_state="expanded"
16
+ )
17
+
18
+ # --------------------------------------------------------------------
19
+ # Helper Functions for Mermaid Parsing and SVG Generation
20
+ # --------------------------------------------------------------------
21
+ def parse_mermaid_to_svg(mermaid_code, layout="flowchart"):
22
+ """
23
+ Parses Mermaid code to extract nodes and edges and generates SVG elements based on the chosen layout.
24
+
25
+ :param mermaid_code: Mermaid graph syntax as a string.
26
+ :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
27
+ :return: SVG content as a string.
28
+ """
29
+ nodes = {}
30
+ edges = []
31
+
32
+ # Extract nodes
33
+ node_pattern = re.compile(r'(\w+)\[(.*?)\]')
34
+ for match in node_pattern.finditer(mermaid_code):
35
+ node_id, label = match.groups()
36
+ nodes[node_id] = label
37
+
38
+ # Extract edges
39
+ edge_pattern = re.compile(r'(\w+)\s*-->\s*(\w+)')
40
+ for match in edge_pattern.finditer(mermaid_code):
41
+ source, target = match.groups()
42
+ edges.append((source, target))
43
+
44
+ # Initialize SVG content
45
+ svg_content = '''
46
+ <svg viewBox="0 0 1200 800" xmlns="http://www.w3.org/2000/svg">
47
+ <!-- Background -->
48
+ <rect width="1200" height="800" fill="#ffffff"/>
49
+ '''
50
+
51
+ # Layout logic
52
+ if layout == "flowchart":
53
+ columns = 4 # Default columns for flowchart
54
+ spacing_x = 250
55
+ spacing_y = 150
56
+ start_x = 100
57
+ start_y = 100
58
+
59
+ node_positions = {}
60
+
61
+ for i, (node_id, label) in enumerate(nodes.items()):
62
+ x = start_x + (i % columns) * spacing_x
63
+ y = start_y + (i // columns) * spacing_y
64
+ node_positions[node_id] = (x, y)
65
+
66
+ svg_content += f'''
67
+ <g>
68
+ <rect x="{x - 75}" y="{y - 25}" width="150" height="50" fill="#4CAF50" rx="10" ry="10"/>
69
+ <text x="{x}" y="{y + 5}" text-anchor="middle" fill="white" font-family="Arial" font-size="14">{label}</text>
70
+ </g>
71
+ '''
72
+
73
+ # Draw edges
74
+ for source, target in edges:
75
+ if source in node_positions and target in node_positions:
76
+ x1, y1 = node_positions[source]
77
+ x2, y2 = node_positions[target]
78
+ svg_content += f'''
79
+ <line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="#000000" stroke-width="2" marker-end="url(#arrowhead)"/>
80
+ '''
81
+ elif layout == "wireframe":
82
+ center_x, center_y = 600, 400
83
+ radius = max(150, min(300, 250 + 10 * len(nodes)))
84
+ angle_step = 2 * math.pi / max(1, len(nodes)) # Avoid division by zero
85
+
86
+ node_positions = {}
87
+ for i, (node_id, label) in enumerate(nodes.items()):
88
+ angle = i * angle_step
89
+ x = center_x + radius * math.cos(angle)
90
+ y = center_y + radius * math.sin(angle)
91
+ node_positions[node_id] = (x, y)
92
+
93
+ svg_content += f'''
94
+ <g>
95
+ <rect x="{x - 75}" y="{y - 25}" width="150" height="50" fill="#f39c12" rx="10" ry="10"/>
96
+ <text x="{x}" y="{y + 5}" text-anchor="middle" fill="black" font-family="Arial" font-size="14">{label}</text>
97
+ </g>
98
+ '''
99
+
100
+ # Draw edges
101
+ for source, target in edges:
102
+ if source in node_positions and target in node_positions:
103
+ x1, y1 = node_positions[source]
104
+ x2, y2 = node_positions[target]
105
+ svg_content += f'''
106
+ <line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="#d35400" stroke-width="2" marker-end="url(#arrowhead)"/>
107
+ '''
108
+
109
+ # Add arrowhead for edges
110
+ svg_content += '''
111
+ <defs>
112
+ <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="10" refY="3.5" orient="auto">
113
+ <polygon points="0 0, 10 3.5, 0 7" fill="black"/>
114
+ </marker>
115
+ </defs>
116
+ '''
117
+
118
+ svg_content += '</svg>'
119
+ return svg_content
120
+
121
+ def generate_and_save_final_image(mermaid_code, layout="flowchart"):
122
+ """
123
+ Generates the final image based on Mermaid code and saves it as a PNG file.
124
+
125
+ :param mermaid_code: Mermaid graph syntax as a string.
126
+ :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
127
+ :return: Tuple (success, output_file or error message).
128
+ """
129
+ try:
130
+ # Parse Mermaid to SVG
131
+ svg_content = parse_mermaid_to_svg(mermaid_code, layout)
132
+
133
+ # Ensure the output directory exists
134
+ output_directory = "images"
135
+ os.makedirs(output_directory, exist_ok=True)
136
+
137
+ # Define output file path
138
+ output_file = os.path.join(output_directory, f"final_image_{layout}.png")
139
+
140
+ # Convert SVG to PNG and save
141
+ cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=output_file)
142
+ return True, output_file
143
+ except Exception as e:
144
+ return False, str(e)
145
+
146
+ # --------------------------------------------------------------------
147
+ # Supported Models
148
+ # --------------------------------------------------------------------
149
+ SUPPORTED_MODELS: Dict[str, str] = {
150
+ "Llama 3 8B": "llama3-8b-8192",
151
+ "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
152
+ "Llama 3 70B": "llama3-70b-8192",
153
+ "Mixtral 8x7B": "mixtral-8x7b-32768",
154
+ "Gemma 2 9B": "gemma2-9b-it",
155
+ "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
156
+ "Llama 3.2 11B Text (Preview)": "llama-3.2-11b-text-preview",
157
+ "Llama 3.1 8B Instant (Text-Only Workloads)": "llama-3.1-8b-instant",
158
+ "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
159
+ "Llama 3.1 70B Versatile": "llama-3.1-70b-versatile",
160
+ "Llama 3.1 8B Instant": "llama-3.1-8b-instant",
161
+ "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
162
+ "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
163
+ "Llama 3.2 3B (Preview)": "llama-3.2-3b-preview",
164
+ "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
165
+ "Llama 3.3 70B SpecDec": "llama-3.3-70b-specdec",
166
+ "Llama 3.3 70B Versatile": "llama-3.3-70b-versatile",
167
+ }
168
+
169
+ MAX_TOKENS: int = 1500
170
+
171
+ # --------------------------------------------------------------------
172
+ # Initialize Groq client with API key
173
+ # --------------------------------------------------------------------
174
+ @st.cache_resource
175
+ def get_groq_client() -> Optional[Groq]:
176
+ groq_api_key = os.getenv("GROQ_API_KEY")
177
+ if not groq_api_key:
178
+ st.error("GROQ_API_KEY not found in environment variables. Please set it and restart the app.")
179
+ return None
180
+ return Groq(api_key=groq_api_key)
181
+
182
+ client = get_groq_client()
183
+
184
+ # --------------------------------------------------------------------
185
+ # SIDEBAR
186
+ # --------------------------------------------------------------------
187
+ st.sidebar.image("icon.png", width=300)
188
+ st.sidebar.title("Model Configuration")
189
+
190
+ selected_model = st.sidebar.selectbox("Choose an AI Model", list(SUPPORTED_MODELS.keys()))
191
+
192
+ st.sidebar.subheader("Temperature")
193
+ temperature = st.sidebar.slider(
194
+ "Set temperature for generation variability:",
195
+ min_value=0.0,
196
+ max_value=1.0,
197
+ value=0.7
198
+ )
199
+
200
+ # Add layout selection
201
+ st.sidebar.subheader("Layout Configuration")
202
+ layout = st.sidebar.radio(
203
+ "Select the layout for the mind map:",
204
+ options=["flowchart", "wireframe"]
205
+ )
206
+
207
+ # --------------------------------------------------------------------
208
+ # MAIN CONTENT
209
+ # --------------------------------------------------------------------
210
+ st.title("AI Mind Map Generator")
211
+ st.markdown(
212
+ """
213
+ Enter your concepts or a short description below, then click **Generate Mind Map**.
214
+ The Groq LLM will produce Mermaid diagram code, which we'll display below.
215
+ """
216
+ )
217
+
218
+ # Text area for user input
219
+ mind_map_prompt = st.text_area(
220
+ "Describe your mind map focus:",
221
+ placeholder="e.g. 'Attention and Intention in personal development'"
222
+ )
223
+
224
+ if st.button("Generate Mind Map"):
225
+ if not mind_map_prompt.strip():
226
+ st.warning("Please provide a description or concept for the mind map.")
227
+ elif client:
228
+ with st.spinner("Generating your mind map..."):
229
+ prompt = f"""
230
+ You are an AI that generates a Mind Map in Mermaid format.
231
+ The user wants a mind map about: {mind_map_prompt}.
232
+ Please output ONLY the Mermaid diagram, nothing else.
233
+ """
234
+
235
+ try:
236
+ response = client.chat.completions.create(
237
+ model=SUPPORTED_MODELS[selected_model],
238
+ messages=[
239
+ {"role": "system", "content": "You are an AI that generates mind maps in Mermaid code."},
240
+ {"role": "user", "content": prompt},
241
+ ],
242
+ temperature=temperature,
243
+ max_tokens=MAX_TOKENS,
244
+ )
245
+
246
+ mermaid_code = response.choices[0].message.content.strip()
247
+
248
+ st.subheader("Generated Mind Map")
249
+ st.markdown(
250
+ f"""
251
+ ```mermaid
252
+ {mermaid_code}
253
+ ```
254
+ """,
255
+ unsafe_allow_html=True
256
+ )
257
+
258
+ st.download_button(
259
+ label="Download Mermaid Code",
260
+ data=mermaid_code,
261
+ file_name="mind_map_mermaid.txt",
262
+ mime="text/plain"
263
+ )
264
+
265
+ # Generate and display the final image based on layout
266
+ success, result = generate_and_save_final_image(mermaid_code, layout)
267
+ if success:
268
+ st.image(result, caption=f"Generated Mind Map ({layout.capitalize()} Layout)", use_container_width=True)
269
+ else:
270
+ st.error(f"Failed to generate image: {result}")
271
+
272
+ except Exception as e:
273
+ st.error(f"Error generating mind map: {e}")
274
+ else:
275
+ st.error("Groq client not initialized. Make sure you have set your GROQ_API_KEY environment variable.")
276
+
277
+ st.info("Built by dw — This app uses Groq LLM to generate Mermaid-based mind maps.")
278
+