Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import librosa | |
| import numpy as np | |
| import re | |
| import os | |
| import json | |
| import tempfile | |
| import networkx as nx | |
| import plotly.graph_objects as go | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| import spaces | |
# Hugging Face checkpoint that listens to the recording and emits the JSON graph.
model_id = "google/gemma-4-e2b-it"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Processor and model are loaded once, when this module is imported.
processor = AutoProcessor.from_pretrained(model_id)
# Weights are loaded in bfloat16, then moved to the selected device.
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
model = model.to(device)
# Audio is resampled to this rate (mono) before it is handed to the model.
_SAMPLE_RATE = 16000
# Recordings shorter than this many seconds are padded with trailing silence.
_MIN_AUDIO_SECONDS = 1.5
# Keys checked, in order, when the model describes a node as an object.
_NODE_NAME_KEYS = ("id", "name", "label")
# Body of a Markdown code fence, tagged ```json or untagged, whether or not
# the fence markers sit on their own lines.
_FENCED_BLOCK_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)

_EXTRACTION_PROMPT = """You are an expert systems architect. Listen to the provided audio detailing a system design.
Extract the core components and their relationships.
Return ONLY a valid JSON object strictly matching this exact structure:
{
"nodes": ["Node A", "Node B"],
"edges": [{"source": "Node A", "target": "Node B"}]
}
Do not include any other text, explanations, or code. ONLY return the raw JSON object inside a ```json block."""


def _load_audio(audio_path):
    """Load *audio_path* as 16 kHz mono samples, zero-padded to at least 1.5 seconds."""
    audio_array, sampling_rate = librosa.load(audio_path, sr=_SAMPLE_RATE, mono=True)
    min_samples = int(_MIN_AUDIO_SECONDS * sampling_rate)
    if len(audio_array) < min_samples:
        # np.pad's default constant mode appends zeros, i.e. silence.
        audio_array = np.pad(audio_array, (0, min_samples - len(audio_array)))
    return audio_array


def _run_model(audio_array):
    """Send the audio plus the extraction prompt to the model and return its decoded reply."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_array},
                {"type": "text", "text": _EXTRACTION_PROMPT},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=1500)
    # Skip the prompt tokens so only the newly generated text is decoded.
    input_length = inputs["input_ids"].shape[1]
    return processor.decode(outputs[0][input_length:], skip_special_tokens=True)


def _extract_json_text(response_text: str) -> str:
    """Return the most likely JSON object substring of a model reply.

    Prefers the contents of a Markdown code fence. Then it narrows to the span
    between the first "{" and the last "}", which also recovers replies whose
    closing fence was cut off or that wrap the JSON in extra prose.
    """
    match = _FENCED_BLOCK_RE.search(response_text)
    candidate = match.group(1) if match else response_text
    start, end = candidate.find("{"), candidate.rfind("}")
    if start != -1 and end > start:
        return candidate[start:end + 1]
    return candidate.strip()


def _node_name(item):
    """Return a normalized node name for *item*, or None if it cannot name a node.

    Strings are stripped (empty -> None), numbers are stringified, and dicts
    are resolved through the first usable "id", "name" or "label" value.
    """
    if isinstance(item, dict):
        for key in _NODE_NAME_KEYS:
            name = _node_name(item.get(key))
            if name:
                return name
        return None
    if isinstance(item, bool):  # bool is an int subclass; never a sensible name
        return None
    if isinstance(item, (int, float)):
        return str(item)
    if isinstance(item, str):
        return item.strip() or None
    return None


def _build_graph(architecture: dict):
    """Build a networkx DiGraph from a {"nodes": [...], "edges": [...]} object.

    Non-list "nodes"/"edges" values (including null) are treated as empty.
    Edges may be {"source": ..., "target": ...} objects or [source, target, ...]
    lists; edges whose endpoints cannot be named are skipped.
    """
    graph = nx.DiGraph()
    nodes = architecture.get("nodes")
    edges = architecture.get("edges")
    nodes = nodes if isinstance(nodes, list) else []
    edges = edges if isinstance(edges, list) else []

    for num, node in enumerate(nodes):
        # Unnamed entries still get a placeholder so every listed component shows up.
        graph.add_node(_node_name(node) or f"Node_{num}")

    for edge in edges:
        if isinstance(edge, dict):
            source, target = edge.get("source"), edge.get("target")
        elif isinstance(edge, list) and len(edge) >= 2:
            source, target = edge[0], edge[1]
        else:
            continue
        source, target = _node_name(source), _node_name(target)
        if source and target:
            graph.add_edge(source, target)
    return graph


def _layout_positions(graph):
    """Return {node: (x, y)}, using Graphviz "dot" when possible, else a seeded spring layout."""
    try:
        pos = nx.nx_pydot.graphviz_layout(graph, prog="dot")
    except Exception:
        # Best effort: a missing pydot/Graphviz install, or node names pydot
        # cannot handle, should degrade the layout rather than fail the request.
        pos = None
    if not pos or any(node not in pos for node in graph):
        pos = nx.spring_layout(graph, seed=42)
    return pos


def _render_figure(graph, pos):
    """Draw *graph* as a Plotly figure: grey edge segments plus labelled orange node markers."""
    edge_x, edge_y = [], []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # A None entry breaks the polyline so each edge is a separate segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    node_names = list(graph.nodes())
    node_x = [pos[name][0] for name in node_names]
    node_y = [pos[name][1] for name in node_names]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    ))
    fig.add_trace(go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=node_names,
        textposition="top center",
        marker=dict(size=30, color='darkorange', line=dict(width=2, color='black'))
    ))
    fig.update_layout(
        title_text='Rendered Architecture',
        title_font_size=20,
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
    return fig


def process_audio(audio_path):
    """Turn a spoken architecture description into an interactive graph.

    The recording is sent to the module-level model with a prompt asking for a
    {"nodes": [...], "edges": [...]} JSON object. The object is converted into
    a directed graph and drawn with Plotly.

    NOTE(review): `spaces` is imported at the top of the file but never used.
    If this Space runs on ZeroGPU hardware, this function likely needs the
    `@spaces.GPU` decorator. Confirm the Space's hardware.

    Args:
        audio_path: Path to the recorded audio file (Gradio ``type="filepath"``),
            or a falsy value when nothing was recorded.

    Returns:
        A 2-tuple ``(figure, markdown)`` matching the two Gradio outputs:
        the Plotly figure (``None`` on failure) and a Markdown log holding
        either the extracted JSON or an error message.
    """
    if not audio_path:
        # Exactly two values: one per output component wired to the button.
        return None, "No audio provided. Please provide an audio recording detailing the architecture."
    try:
        audio_array = _load_audio(audio_path)
        response_text = _run_model(audio_array)
        json_str = _extract_json_text(response_text)
        try:
            architecture = json.loads(json_str)
        except json.JSONDecodeError as e:
            return None, f"Model failed to return valid JSON:\n{response_text}\n\nParse Error: {e}"
        if not isinstance(architecture, dict):
            return None, f"Model returned JSON that is not an object:\n{response_text}"

        graph = _build_graph(architecture)
        if graph.number_of_nodes() == 0:
            return None, f"No components could be extracted from the model output:\n{response_text}"

        fig = _render_figure(graph, _layout_positions(graph))
        extracted = json.dumps(architecture, indent=2, ensure_ascii=False)
        output_logs = f"**Data Extracted by Gemma4:**\n\n```json\n{extracted}\n```"
        return fig, output_logs
    except Exception as e:  # top-level UI boundary: report the error instead of failing the request
        import traceback
        traceback.print_exc()  # keep the full stack trace in the Space logs
        return None, f"An error occurred in the pipeline: {str(e)}"
with gr.Blocks(title="Voice-to-Graph System") as demo:
    # Page header and a short usage blurb.
    gr.Markdown("# 🎙️ Voice-to-Graph System")
    gr.Markdown(
        "Speak your system design out loud. The bot will process your audio "
        "and dynamically render an interactive graph directly inside this web interface."
    )

    with gr.Row():
        # Left column: microphone input, trigger button and a download hint.
        with gr.Column(scale=1):
            audio_in = gr.Audio(
                sources=["microphone"],
                type="filepath",  # the handler receives a file path rather than raw samples
                label="Record Architecture Description",
            )
            submit_btn = gr.Button("Generate Visual Architecture", variant="primary")
            gr.Markdown(
                "> 💡 **Tip:** To download a high-res image of your generated graph, "
                "simply hover over the top right corner of the visualization and click "
                "the **camera icon (`Download plot as a png`)**."
            )
        # Right column: the rendered graph and the extracted-data log.
        with gr.Column(scale=2):
            graph_plot = gr.Plot(label="Interactive Architecture Graph")
            code_out = gr.Markdown(label="Execution Logs & Generated Code")

    # Clicking runs the pipeline on the recording; results go to the plot, then the log.
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_in],
        outputs=[graph_plot, code_out],
    )

if __name__ == "__main__":
    # NOTE(review): newer Gradio releases accept `theme` on launch(); older ones
    # expect it on gr.Blocks(...). Confirm against the Space's pinned Gradio version.
    demo.launch(theme=gr.themes.Ocean())