Voice-to-Graph / app.py
RohanSardar's picture
added GPU support
acccc25 verified
import gradio as gr
import torch
import librosa
import numpy as np
import re
import os
import json
import tempfile
import networkx as nx
import plotly.graph_objects as go
from transformers import AutoProcessor, AutoModelForCausalLM
import spaces
# Checkpoint identifier handed to both from_pretrained() calls below.
model_id = "google/gemma-4-e2b-it"

# Use CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and model are created once at import time; process_audio()
# reads these module-level objects on every call.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
model = model.to(device)
# Instruction sent to the model together with the audio clip.
_EXTRACTION_PROMPT = """You are an expert systems architect. Listen to the provided audio detailing a system design.
Extract the core components and their relationships.
Return ONLY a valid JSON object strictly matching this exact structure:
{
"nodes": ["Node A", "Node B"],
"edges": [{"source": "Node A", "target": "Node B"}]
}
Do not include any other text, explanations, or code. ONLY return the raw JSON object inside a ```json block."""

# First Markdown code fence, with or without a "json" tag. Any whitespace
# (including none, or CRLF) is tolerated between the fence and the payload.
_FENCED_BLOCK_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def _extract_json_payload(response_text):
    """Return the substring of *response_text* most likely to be the JSON object."""
    fenced = _FENCED_BLOCK_RE.search(response_text)
    if fenced:
        return fenced.group(1).strip()
    # Missing or unterminated fence: fall back to the outermost {...} span so
    # replies such as "Here is the JSON: {...}" still parse.
    start = response_text.find("{")
    end = response_text.rfind("}")
    if start != -1 and end > start:
        return response_text[start:end + 1]
    return response_text


def _build_graph(architecture):
    """Build a directed graph from a parsed {"nodes": [...], "edges": [...]} dict."""
    graph = nx.DiGraph()
    # `or []` also covers an explicit JSON null for either key.
    for index, node in enumerate(architecture.get("nodes") or []):
        if isinstance(node, dict) and "id" in node:
            graph.add_node(node["id"])
        elif isinstance(node, str):
            graph.add_node(node)
        else:
            # Unrecognised node shape: keep a positional placeholder.
            graph.add_node(f"Node_{index}")
    for edge in architecture.get("edges") or []:
        if isinstance(edge, dict) and "source" in edge and "target" in edge:
            graph.add_edge(edge["source"], edge["target"])
        elif isinstance(edge, list) and len(edge) >= 2:
            graph.add_edge(edge[0], edge[1])
        # Any other edge shape is ignored.
    return graph


def _render_figure(graph, pos):
    """Draw *graph* as a Plotly figure using the node coordinates in *pos*."""
    fig = go.Figure()

    # All edges go into one line trace; a None entry breaks the line between
    # consecutive segments.
    edge_x, edge_y = [], []
    for src, dst in graph.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    ))

    labels = list(graph.nodes())
    fig.add_trace(go.Scatter(
        x=[pos[label][0] for label in labels],
        y=[pos[label][1] for label in labels],
        mode='markers+text',
        hoverinfo='text',
        text=labels,
        textposition="top center",
        marker=dict(size=30, color='darkorange', line=dict(width=2, color='black'))
    ))

    fig.update_layout(
        title_text='Rendered Architecture',
        title_font_size=20,
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
    return fig


@spaces.GPU
def process_audio(audio_path):
    """Turn a spoken architecture description into a graph figure.

    The audio file is loaded at 16 kHz mono, sent to the model with an
    extraction prompt, and the JSON in the reply is parsed into a directed
    graph that is laid out with Graphviz ``dot`` and drawn with Plotly.

    Args:
        audio_path: Path to the recorded audio file; falsy when nothing was
            recorded.

    Returns:
        A ``(figure, markdown)`` pair on every path, matching the two output
        components the click handler is wired to. ``figure`` is ``None``
        whenever the pipeline could not produce a graph, in which case
        ``markdown`` carries the error message.
    """
    if not audio_path:
        # Two values, not three: the handler has exactly two outputs.
        return None, "No audio provided. Please provide an audio recording detailing the architecture."

    try:
        # Load audio at 16kHz, mono
        audio_array, sampling_rate = librosa.load(audio_path, sr=16000, mono=True)

        # Zero-pad clips shorter than 1.5 seconds up to that length.
        min_samples = int(1.5 * sampling_rate)
        if len(audio_array) < min_samples:
            audio_array = np.pad(audio_array, (0, min_samples - len(audio_array)))

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_array},
                    {"type": "text", "text": _EXTRACTION_PROMPT},
                ],
            },
        ]
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=1500)

        # Decode only the newly generated tokens, skipping the echoed prompt.
        input_length = inputs["input_ids"].shape[1]
        response_text = processor.decode(outputs[0][input_length:], skip_special_tokens=True)

        try:
            architecture = json.loads(_extract_json_payload(response_text))
        except json.JSONDecodeError as e:
            return None, f"Model failed to return valid JSON:\n{response_text}\n\nParse Error: {e}"

        if not isinstance(architecture, dict):
            # Valid JSON, but not the requested object (e.g. a bare list).
            return None, f"Model returned JSON that is not an object with \"nodes\" and \"edges\":\n{response_text}"

        graph = _build_graph(architecture)
        pos = nx.nx_pydot.graphviz_layout(graph, prog='dot')
        fig = _render_figure(graph, pos)

        output_logs = f"**Data Extracted by Gemma4:**\n\n```json\n{json.dumps(architecture, indent=2)}\n```"
        return fig, output_logs
    except Exception as e:
        # UI boundary: report any failure in the log pane rather than raising.
        return None, f"An error occurred in the pipeline: {e}"
# User-facing copy, kept apart from the layout code for readability.
_APP_TITLE = "Voice-to-Graph System"
_HEADING_MD = "# 🎙️ Voice-to-Graph System"
_INTRO_MD = "Speak your system design out loud. The bot will process your audio and dynamically render an interactive graph directly inside this web interface."
_DOWNLOAD_TIP_MD = "> 💡 **Tip:** To download a high-res image of your generated graph, simply hover over the top right corner of the visualization and click the **camera icon (`Download plot as a png`)**."

with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(_HEADING_MD)
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Narrow column: recording input, trigger button and usage tip.
        with gr.Column(scale=1):
            audio_in = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record Architecture Description",
            )
            submit_btn = gr.Button("Generate Visual Architecture", variant="primary")
            gr.Markdown(_DOWNLOAD_TIP_MD)

        # Wide column: rendered graph and the text returned next to it.
        with gr.Column(scale=2):
            graph_plot = gr.Plot(label="Interactive Architecture Graph")
            code_out = gr.Markdown(label="Execution Logs & Generated Code")

    # process_audio's two return values fill the plot and the Markdown pane.
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_in],
        outputs=[graph_plot, code_out],
    )
if __name__ == "__main__":
    # Start the web app only when this file is run as a script.
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)