"""Streamlit app: ask a Groq-hosted LLM for Mermaid code and render it as a PNG."""

import html
import math
import os
import re
from typing import Dict, List, Optional, Tuple

import cairosvg
import streamlit as st
from groq import Groq

# --------------------------------------------------------------------
# Streamlit page configuration
# --------------------------------------------------------------------
st.set_page_config(
    layout="wide",
    page_title="AI Mind Map Generator",
    initial_sidebar_state="expanded",
)

# --------------------------------------------------------------------
# Helper Functions for Mermaid Parsing and SVG Generation
# --------------------------------------------------------------------
# `id[Label]` node declarations.
_NODE_PATTERN = re.compile(r"(\w+)\[(.*?)\]")

# `source --> target` links. The source may carry an inline `[Label]`, the
# arrow may carry a `|text|` label, and the target is captured in a lookahead
# so it is not consumed and can act as the source of a chained link
# (`A --> B --> C`).
_EDGE_PATTERN = re.compile(
    r"(\w+)(?:\[[^\]]*\])?\s*-->\s*(?:\|[^|]*\|\s*)?(?=(\w+))"
)

# A reply that consists solely of one fenced Markdown code block.
_FENCE_PATTERN = re.compile(r"^```[\w-]*[ \t]*\n(.*?)\n?```$", re.DOTALL)

# Drawing geometry, in pixels.
_NODE_WIDTH = 150
_NODE_HEIGHT = 50
_NODE_RADIUS = 50
_CANVAS_MARGIN = 100
_MIN_CANVAS_WIDTH = 1200
_MIN_CANVAS_HEIGHT = 800

_ARROWHEAD_DEFS = (
    '<defs><marker id="arrowhead" markerWidth="10" markerHeight="7" '
    'refX="10" refY="3.5" orient="auto">'
    '<polygon points="0 0, 10 3.5, 0 7" fill="black"/>'
    "</marker></defs>"
)


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if it has one."""
    stripped = text.strip()
    fenced = _FENCE_PATTERN.match(stripped)
    return fenced.group(1).strip() if fenced else stripped


def _extract_graph(mermaid_code: str) -> Tuple[Dict[str, str], List[Tuple[str, str]]]:
    """Collect ``{node_id: label}`` and ``[(source, target), ...]`` from Mermaid text."""
    nodes: Dict[str, str] = {}
    for node_id, label in _NODE_PATTERN.findall(mermaid_code):
        # Mermaid allows labels to be quoted: id["Label"].
        nodes[node_id] = label.strip().strip('"')

    edges: List[Tuple[str, str]] = []
    for source, target in _EDGE_PATTERN.findall(mermaid_code):
        edges.append((source, target))
        # A node that only ever appears in a link is labelled with its id.
        nodes.setdefault(source, source)
        nodes.setdefault(target, target)
    return nodes, edges


def _layout_positions(node_ids: List[str], layout: str) -> Dict[str, Tuple[float, float]]:
    """Map each node id to the centre point of its shape for the given layout."""
    if layout == "flowchart":
        columns = 4  # Default columns for flowchart
        spacing_x, spacing_y = 250, 150
        start_x, start_y = 100, 100
        return {
            node_id: (
                start_x + (index % columns) * spacing_x,
                start_y + (index // columns) * spacing_y,
            )
            for index, node_id in enumerate(node_ids)
        }

    if layout == "wireframe":
        center_x, center_y = 600, 400
        radius = max(150, min(300, 250 + 10 * len(node_ids)))
        angle_step = 2 * math.pi / max(1, len(node_ids))  # Avoid division by zero
        return {
            node_id: (
                center_x + radius * math.cos(index * angle_step),
                center_y + radius * math.sin(index * angle_step),
            )
            for index, node_id in enumerate(node_ids)
        }

    # Unknown layout: nothing is placed, so only the empty canvas is drawn.
    return {}


def _edge_end(
    x1: float, y1: float, x2: float, y2: float, circular: bool
) -> Optional[Tuple[float, float]]:
    """
    Return the point where the segment (x1, y1) -> (x2, y2) meets the outline
    of the target node, so the arrowhead is not hidden underneath the shape.
    Returns None for a zero-length segment (e.g. a self-loop).
    """
    dx, dy = x2 - x1, y2 - y1
    length = math.hypot(dx, dy)
    if length == 0:
        return None
    if circular:
        shrink = _NODE_RADIUS / length
    else:
        # Fraction of the segment lying inside the target rectangle.
        shrink = min(
            (_NODE_WIDTH / 2) / abs(dx) if dx else math.inf,
            (_NODE_HEIGHT / 2) / abs(dy) if dy else math.inf,
        )
    shrink = min(shrink, 1.0)
    return x2 - dx * shrink, y2 - dy * shrink


def parse_mermaid_to_svg(mermaid_code: str, layout: str = "flowchart") -> str:
    """
    Parses Mermaid code to extract nodes and edges and generates an SVG
    document based on the chosen layout.

    Only flowchart-style syntax is understood: ``id[Label]`` node
    declarations and ``A --> B`` links (optionally ``A[Label] --> B[Label]``,
    ``A -->|text| B`` or chained ``A --> B --> C``). Nodes that appear only
    in links are drawn with their id as the label.

    :param mermaid_code: Mermaid graph syntax as a string.
    :param layout: The desired layout type: "flowchart" places nodes on a
        4-column grid, "wireframe" places them on a circle. Any other value
        yields an empty canvas.
    :return: SVG content as a string.
    """
    nodes, edges = _extract_graph(mermaid_code)
    positions = _layout_positions(list(nodes), layout)
    circular = layout == "wireframe"

    # Grow the canvas with the grid so that later rows are not clipped.
    width = max(
        _MIN_CANVAS_WIDTH,
        max((x + _CANVAS_MARGIN for x, _ in positions.values()), default=0),
    )
    height = max(
        _MIN_CANVAS_HEIGHT,
        max((y + _CANVAS_MARGIN for _, y in positions.values()), default=0),
    )

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width:.0f}" '
        f'height="{height:.0f}" viewBox="0 0 {width:.0f} {height:.0f}">',
        _ARROWHEAD_DEFS,
        '<rect width="100%" height="100%" fill="white"/>',
    ]

    # Edges are emitted first so node shapes and labels are painted on top.
    for source, target in edges:
        if source not in positions or target not in positions:
            continue
        x1, y1 = positions[source]
        x2, y2 = positions[target]
        end = _edge_end(x1, y1, x2, y2, circular)
        if end is None:
            continue
        parts.append(
            f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{end[0]:.1f}" y2="{end[1]:.1f}" '
            'stroke="black" stroke-width="2" marker-end="url(#arrowhead)"/>'
        )

    for node_id, (x, y) in positions.items():
        if circular:
            parts.append(
                f'<circle cx="{x:.1f}" cy="{y:.1f}" r="{_NODE_RADIUS}" '
                'fill="lightgreen" stroke="black" stroke-width="2"/>'
            )
        else:
            parts.append(
                f'<rect x="{x - _NODE_WIDTH / 2:.1f}" y="{y - _NODE_HEIGHT / 2:.1f}" '
                f'width="{_NODE_WIDTH}" height="{_NODE_HEIGHT}" rx="8" '
                'fill="lightblue" stroke="black" stroke-width="2"/>'
            )
        # Labels come from LLM output, so escape them before embedding in XML.
        # The +5 offset roughly centres 14px text vertically on the node.
        parts.append(
            f'<text x="{x:.1f}" y="{y + 5:.1f}" font-family="Arial, sans-serif" '
            f'font-size="14" text-anchor="middle" fill="black">'
            f"{html.escape(nodes[node_id])}</text>"
        )

    parts.append("</svg>")
    return "\n".join(parts)


def generate_and_save_final_image(
    mermaid_code: str, layout: str = "flowchart"
) -> Tuple[bool, str]:
    """
    Generates the final image based on Mermaid code and saves it as a PNG file
    at ``images/final_image_<layout>.png`` (the directory is created if needed).

    :param mermaid_code: Mermaid graph syntax as a string.
    :param layout: The desired layout type (e.g., "wireframe" or "flowchart").
    :return: Tuple (success, output_file or error message).
    """
    try:
        # Parse Mermaid to SVG
        svg_content = parse_mermaid_to_svg(mermaid_code, layout)

        # Ensure the output directory exists
        output_directory = "images"
        os.makedirs(output_directory, exist_ok=True)

        # Define output file path
        output_file = os.path.join(output_directory, f"final_image_{layout}.png")

        # Convert SVG to PNG and save
        cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=output_file)
        return True, output_file
    except Exception as e:
        # UI boundary: the caller shows the message instead of crashing the app.
        return False, str(e)


# --------------------------------------------------------------------
# Supported Models
# --------------------------------------------------------------------
# Display name -> Groq model id. Each display name is listed once; repeated
# keys in a dict literal are silently collapsed by Python.
SUPPORTED_MODELS: Dict[str, str] = {
    "Llama 3 8B": "llama3-8b-8192",
    "Llama 3.2 1B (Preview)": "llama-3.2-1b-preview",
    "Llama 3 70B": "llama3-70b-8192",
    "Mixtral 8x7B": "mixtral-8x7b-32768",
    "Gemma 2 9B": "gemma2-9b-it",
    "Llama 3.2 11B Vision (Preview)": "llama-3.2-11b-vision-preview",
    "Llama 3.2 11B Text (Preview)": "llama-3.2-11b-text-preview",
    "Llama 3.1 8B Instant (Text-Only Workloads)": "llama-3.1-8b-instant",
    "Llama 3.2 90B Vision (Preview)": "llama-3.2-90b-vision-preview",
    "Llama 3.1 70B Versatile": "llama-3.1-70b-versatile",
    "Llama 3.1 8B Instant": "llama-3.1-8b-instant",
    "Llama 3.2 3B (Preview)": "llama-3.2-3b-preview",
    "Llama 3.3 70B SpecDec": "llama-3.3-70b-specdec",
    "Llama 3.3 70B Versatile": "llama-3.3-70b-versatile",
}

MAX_TOKENS: int = 1500


# --------------------------------------------------------------------
# Initialize Groq client with API key
# --------------------------------------------------------------------
@st.cache_resource
def get_groq_client() -> Optional[Groq]:
    """
    Build the Groq client from the GROQ_API_KEY environment variable.

    :return: A Groq client, or None (after showing an error in the UI) when
        the environment variable is unset or empty.
    """
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        st.error("GROQ_API_KEY not found in environment variables. Please set it and restart the app.")
        return None
    return Groq(api_key=groq_api_key)


client = get_groq_client()

# --------------------------------------------------------------------
# SIDEBAR
# --------------------------------------------------------------------
st.sidebar.image("icon.png", width=300)
st.sidebar.title("Model Configuration")

selected_model = st.sidebar.selectbox("Choose an AI Model", list(SUPPORTED_MODELS.keys()))

st.sidebar.subheader("Temperature")
temperature = st.sidebar.slider(
    "Set temperature for generation variability:",
    min_value=0.0,
    max_value=1.0,
    value=0.7,
)

# Add layout selection
st.sidebar.subheader("Layout Configuration")
layout = st.sidebar.radio(
    "Select the layout for the mind map:",
    options=["flowchart", "wireframe"],
)

# --------------------------------------------------------------------
# MAIN CONTENT
# --------------------------------------------------------------------
st.title("AI Mind Map Generator")
st.markdown(
    "Enter your concepts or a short description below, then click "
    "**Generate Mind Map**.\n"
    "The Groq LLM will produce Mermaid diagram code, which we'll display below."
)

# Text area for user input
mind_map_prompt = st.text_area(
    "Describe your mind map focus:",
    placeholder="e.g. 'Attention and Intention in personal development'",
)

if st.button("Generate Mind Map"):
    if not mind_map_prompt.strip():
        st.warning("Please provide a description or concept for the mind map.")
    elif client:
        with st.spinner("Generating your mind map..."):
            # The renderer above only understands flowchart syntax, so the
            # prompt asks for it explicitly.
            prompt = (
                "You are an AI that generates a Mind Map in Mermaid format.\n"
                f"The user wants a mind map about: {mind_map_prompt}.\n"
                "Use Mermaid flowchart syntax (graph TD) with nodes written as "
                "id[Label] and links written as A --> B.\n"
                "Please output ONLY the Mermaid diagram, nothing else."
            )
            try:
                response = client.chat.completions.create(
                    model=SUPPORTED_MODELS[selected_model],
                    messages=[
                        {"role": "system", "content": "You are an AI that generates mind maps in Mermaid code."},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=temperature,
                    max_tokens=MAX_TOKENS,
                )
                # Models often wrap the answer in a code fence despite the
                # instruction; remove it so the fence added below is not nested.
                mermaid_code = _strip_code_fence(response.choices[0].message.content or "")

                st.subheader("Generated Mind Map")
                st.markdown(f"```mermaid\n{mermaid_code}\n```")

                st.download_button(
                    label="Download Mermaid Code",
                    data=mermaid_code,
                    file_name="mind_map_mermaid.txt",
                    mime="text/plain",
                )

                # Generate and display the final image based on layout
                success, result = generate_and_save_final_image(mermaid_code, layout)
                if success:
                    st.image(
                        result,
                        caption=f"Generated Mind Map ({layout.capitalize()} Layout)",
                        use_container_width=True,
                    )
                else:
                    st.error(f"Failed to generate image: {result}")
            except Exception as e:
                # UI boundary: report API/rendering failures to the user.
                st.error(f"Error generating mind map: {e}")
    else:
        st.error("Groq client not initialized. Make sure you have set your GROQ_API_KEY environment variable.")

st.info("Built by dw — This app uses Groq LLM to generate Mermaid-based mind maps.")