"""Voice-to-Graph demo.

Records a spoken system-design description, asks a multimodal model to
extract components and relationships as JSON, and renders the result as an
interactive Plotly graph inside a Gradio interface.
"""

import json
import os
import re
import tempfile
from typing import Optional

import gradio as gr
import torch
import librosa
import numpy as np
import networkx as nx
import plotly.graph_objects as go
from transformers import AutoProcessor, AutoModelForCausalLM
import spaces

model_id = "google/gemma-4-e2b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time so every request reuses the same weights.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16
).to(device)

# Instruction sent alongside the audio. Built by implicit concatenation so the
# text stays exactly what the model was prompted with while remaining readable.
_PROMPT = (
    "You are an expert systems architect. Listen to the provided audio "
    "detailing a system design. Extract the core components and their "
    "relationships. Return ONLY a valid JSON object strictly matching this "
    "exact structure: "
    '{ "nodes": ["Node A", "Node B"], '
    '"edges": [{"source": "Node A", "target": "Node B"}] } '
    "Do not include any other text, explanations, or code. \n"
    "ONLY return the raw JSON object inside a ```json block."
)

# Matches a fenced block with or without the "json" tag and with or without
# newlines around the payload.
_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)

# Clips shorter than this many seconds are zero-padded before inference.
_MIN_AUDIO_SECONDS = 1.5


def _extract_json_payload(response_text: str) -> str:
    """Return the contents of the first fenced block, or the text unchanged."""
    match = _FENCE_RE.search(response_text)
    if match:
        return match.group(1).strip()
    return response_text


def _build_graph(architecture: dict) -> nx.DiGraph:
    """Build a directed graph from the ``nodes`` and ``edges`` of the payload."""
    graph = nx.DiGraph()

    # ``or []`` tolerates an explicit JSON null for either key.
    nodes = architecture.get("nodes") or []
    edges = architecture.get("edges") or []

    for num, node in enumerate(nodes):
        if isinstance(node, dict) and "id" in node:
            graph.add_node(node["id"])
        elif isinstance(node, str):
            graph.add_node(node)
        else:
            # Unrecognised node shape: keep a placeholder so the count is visible.
            graph.add_node(f"Node_{num}")

    for edge in edges:
        if isinstance(edge, dict) and "source" in edge and "target" in edge:
            graph.add_edge(edge["source"], edge["target"])
        elif isinstance(edge, list) and len(edge) >= 2:
            graph.add_edge(edge[0], edge[1])

    return graph


def _render_figure(graph: nx.DiGraph) -> go.Figure:
    """Lay the graph out with Graphviz ``dot`` and draw it as a Plotly figure."""
    pos = nx.nx_pydot.graphviz_layout(graph, prog='dot')

    fig = go.Figure()

    # Edges are drawn as one trace; ``None`` breaks the line between segments.
    edge_x = []
    edge_y = []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    ))

    node_names = list(graph.nodes())
    node_x = [pos[name][0] for name in node_names]
    node_y = [pos[name][1] for name in node_names]

    fig.add_trace(go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=node_names,
        textposition="top center",
        marker=dict(size=30, color='darkorange', line=dict(width=2, color='black'))
    ))

    fig.update_layout(
        title_text='Rendered Architecture',
        title_font_size=20,
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
    return fig


@spaces.GPU
def process_audio(audio_path: Optional[str]):
    """Turn a recorded architecture description into a graph figure and a log.

    Args:
        audio_path: Path to the recorded audio file, or a falsy value when
            nothing was recorded.

    Returns:
        A ``(figure, markdown)`` pair matching the two outputs wired to the
        button. ``figure`` is ``None`` whenever the pipeline could not produce
        a graph, in which case ``markdown`` carries the explanation.
    """
    if not audio_path:
        # Two values, like every other return path: one per declared output.
        return None, "No audio provided. Please provide an audio recording detailing the architecture."

    try:
        # Load audio at 16kHz, mono
        audio_array, sampling_rate = librosa.load(audio_path, sr=16000, mono=True)

        # Zero-pad very short clips up to the minimum length.
        min_samples = int(_MIN_AUDIO_SECONDS * sampling_rate)
        if len(audio_array) < min_samples:
            audio_array = np.pad(audio_array, (0, min_samples - len(audio_array)))

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_array},
                    {"type": "text", "text": _PROMPT},
                ],
            },
        ]

        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=1500)

        # Decode only the newly generated tokens, not the echoed prompt.
        input_length = inputs["input_ids"].shape[1]
        response_text = processor.decode(outputs[0][input_length:], skip_special_tokens=True)

        json_str = _extract_json_payload(response_text)

        try:
            architecture = json.loads(json_str)
        except json.JSONDecodeError as e:
            return None, f"Model failed to return valid JSON:\n{response_text}\n\nParse Error: {e}"

        if not isinstance(architecture, dict):
            # Valid JSON but not an object: report it instead of failing on .get().
            return None, (
                f"Model failed to return valid JSON:\n{response_text}\n\n"
                f"Parse Error: expected a JSON object, got {type(architecture).__name__}"
            )

        fig = _render_figure(_build_graph(architecture))

        output_logs = f"**Data Extracted by Gemma4:**\n\n```json\n{json.dumps(architecture, indent=2)}\n```"

        return fig, output_logs

    except Exception as e:
        # UI boundary: surface any failure as text rather than crashing the app.
        return None, f"An error occurred in the pipeline: {str(e)}"


with gr.Blocks(title="Voice-to-Graph System") as demo:
    gr.Markdown("# 🎙️ Voice-to-Graph System")
    gr.Markdown("Speak your system design out loud. The bot will process your audio and dynamically render an interactive graph directly inside this web interface.")

    with gr.Row():
        with gr.Column(scale=1):
            audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Record Architecture Description")
            submit_btn = gr.Button("Generate Visual Architecture", variant="primary")
            gr.Markdown("> 💡 **Tip:** To download a high-res image of your generated graph, simply hover over the top right corner of the visualization and click the **camera icon (`Download plot as a png`)**.")

        with gr.Column(scale=2):
            graph_plot = gr.Plot(label="Interactive Architecture Graph")
            code_out = gr.Markdown(label="Execution Logs & Generated Code")

    # process_audio returns exactly one value per component listed in outputs.
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_in],
        outputs=[graph_plot, code_out]
    )

if __name__ == "__main__":
    demo.launch(
        theme=gr.themes.Ocean()
    )