AdityaManojShinde's picture
up
8f8f786 verified
Raw
History Blame Contribute Delete
7.51 kB
import html
import os
import re
import tempfile

import spaces
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import mermaid as md
from mermaid.graph import Graph
# Populate os.environ from a local .env file (if any) before the Hub downloads
# below run. NOTE(review): presumably this supplies an access token — confirm
# which variables the deployment expects.
load_dotenv()

MODEL_ID = "google/gemma-4-12B-it"

# Tokenizer and weights are loaded once, at import time, and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # ZeroGPU patches this; handles CPU↔GPU transparently
    dtype=torch.bfloat16,
)
# Instructions sent as the system message on every request; they constrain the
# model to emit bare Mermaid erDiagram source that clean_mermaid can pass straight
# to the renderer. The em dash below was previously mis-encoded ("β€”").
SYSTEM_PROMPT = """You are an expert database architect. When given a description, output ONLY valid Mermaid.js erDiagram syntax.
Rules:
- Start with exactly: erDiagram
- Use UPPER_CASE for entity names
- Use proper Mermaid ER relationships: ||--||, ||--o{, }o--o{, etc.
- Always include field definitions with types (int, string, date, float, bool)
- Mark primary keys with PK, foreign keys with FK
- No markdown fences, no explanation, no comments — raw Mermaid code only
Relationship types:
||--|| exactly one to exactly one
||--o{ exactly one to zero or many
}o--o{ zero or many to zero or many
||--|{ exactly one to one or many
Example output:
erDiagram
USER {
int id PK
string name
string email
date created_at
}
ORDER {
int id PK
int user_id FK
float total
date ordered_at
}
USER ||--o{ ORDER : places
"""
@spaces.GPU
def _run_inference(messages: list) -> str:
    """Generate the assistant reply for a chat-formatted conversation.

    GPU-bound inference; the ``spaces.GPU`` decorator makes it run inside a
    ZeroGPU-allocated slot.

    Args:
        messages: Chat turns as ``{"role": ..., "content": ...}`` dicts, in the
            form accepted by ``tokenizer.apply_chat_template``.

    Returns:
        Only the newly generated text (prompt tokens excluded), decoded with
        special tokens removed.
    """
    # NOTE(review): ``thinking=False`` is forwarded to the chat template as a
    # template variable; confirm this is the name the Gemma template reads
    # (some templates use ``enable_thinking``).
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        thinking=False,
    )
    # The rendered chat template already contains the special tokens it needs
    # (typically BOS); letting the tokenizer add them again would duplicate them.
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            do_sample=True,
        )
    # generate() returns prompt + continuation; slice off the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def mermaid_to_svg(mermaid_code: str):
    """Convert Mermaid source into SVG markup via the mermaid-py package.

    NOTE(review): ``svg_response`` looks like an HTTP response object, i.e.
    mermaid-py appears to render remotely — confirm network access is available.

    Args:
        mermaid_code: Mermaid diagram source (expected to start with ``erDiagram``).

    Returns:
        The SVG markup as a string.

    Raises:
        ValueError: if the rendered text is empty or contains no ``<svg`` tag.
    """
    rendered = md.Mermaid(Graph("er-diagram", mermaid_code))
    svg_text = rendered.svg_response.text
    if svg_text and "<svg" in svg_text:
        return svg_text
    raise ValueError(f"Empty or invalid SVG returned: {svg_text[:200]}")
def clean_svg(svg: str) -> str:
    """Remove ``<script>`` elements from SVG markup and trim surrounding whitespace.

    No custom styling is injected — apart from script removal the Mermaid
    output is returned as-is.

    Both forms are removed, case-insensitively:
      * paired ``<script ...>...</script>`` elements (the closing tag may carry
        trailing whitespace, e.g. ``</script >``);
      * self-closing ``<script .../>`` elements, which SVG permits for external
        scripts. Matching these separately also stops a self-closing tag from
        being paired with a later ``</script>`` and swallowing the markup between.

    Args:
        svg: SVG markup as text.

    Returns:
        The markup without script elements, stripped of leading/trailing whitespace.
    """
    script_re = re.compile(
        r"<script\b[^>]*/>"  # self-closing, e.g. <script href="x.js"/>
        r"|<script\b[\s\S]*?</script\s*>",  # paired element, shortest body
        re.IGNORECASE,
    )
    return script_re.sub("", svg).strip()
def clean_mermaid(raw: str) -> str:
    """Extract Mermaid source from raw model output.

    Drops every line that begins with a Markdown code fence (```), then, if the
    text does not already begin with ``erDiagram``, discards everything before
    the first line that does. When no such line exists, the fence-free text is
    returned unchanged.
    """
    kept = [ln for ln in raw.strip().splitlines() if not ln.strip().startswith("```")]
    joined = "\n".join(kept).strip()
    if joined.startswith("erDiagram"):
        return joined
    # Skip any preamble the model wrote before the diagram header.
    start = next(
        (idx for idx, ln in enumerate(kept) if ln.strip().startswith("erDiagram")),
        None,
    )
    if start is None:
        return joined
    return "\n".join(kept[start:]).strip()
def _hidden_download():
    """Return a Gradio update that clears and hides the SVG download component."""
    return gr.update(value=None, visible=False)


def _save_svg(svg: str) -> str:
    """Write *svg* to ``er_diagram.svg`` in a fresh temp directory and return the path.

    A new directory per call keeps the download filename fixed while preventing
    concurrent requests from overwriting each other's file.
    NOTE(review): these directories are never removed here — confirm the host
    cleans them up on long-running deployments.
    """
    tmp_dir = tempfile.mkdtemp()
    svg_path = os.path.join(tmp_dir, "er_diagram.svg")
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg)
    return svg_path


def generate_er(description):
    """Turn a plain-English schema description into a rendered ER diagram.

    Pipeline: ``_run_inference`` (LLM on ZeroGPU) -> ``clean_mermaid`` ->
    ``mermaid_to_svg`` -> ``clean_svg`` -> SVG saved to a temp file for download.

    Args:
        description: Text from the schema-description textbox. ``None`` is
            treated like an empty string.

    Returns:
        A 4-tuple ``(diagram_html, mermaid_code, status_markdown, download_update)``
        matching the click handler's outputs. Errors inside the pipeline are
        caught and reported through this tuple rather than raised.
    """
    if not (description or "").strip():
        return (
            "<p style='color:#888; padding:2rem; text-align:center;'>"
            "Enter a schema description to generate a diagram.</p>",
            "",
            "No diagram generated yet.",
            _hidden_download(),
        )
    try:
        # 1. Generate Mermaid code via LLM (ZeroGPU)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": description},
        ]
        raw_output = _run_inference(messages)
        mermaid_code = clean_mermaid(raw_output)
        # 2. Render via mermaid-py (no custom theme)
        try:
            svg = clean_svg(mermaid_to_svg(mermaid_code))
        except Exception as e:
            # The message can echo the renderer's response body, so escape it
            # before embedding it in HTML shown by gr.HTML.
            safe_err = html.escape(str(e))
            return (
                f"<div style='color:red; padding:1rem;'>"
                f"<b>Render error:</b> {safe_err}. Paste the code below into "
                f"<a href='https://mermaid.live' target='_blank'>mermaid.live</a> to debug.</div>",
                mermaid_code,
                f"Render failed: {str(e)}",
                _hidden_download(),
            )
        # 3. Save SVG to a temp file for download
        svg_path = _save_svg(svg)
        # 4. Return raw SVG directly (no wrapper div / no custom styling)
        return (
            svg,
            mermaid_code,
            "Rendered via mermaid-py",
            gr.update(value=svg_path, visible=True),
        )
    except Exception as e:
        return (
            f"<div style='color:red; padding:1rem;'><b>Unexpected error:</b> {html.escape(str(e))}</div>",
            "",
            f"Error: {str(e)}",
            _hidden_download(),
        )
# ── CSS ───────────────────────────────────────────────────────────────────────
# Rules target the elem_id values given to the title, subtitle and generate
# button in the layout below. The empty first/last entries reproduce the
# leading and trailing newline of the stylesheet text.
css = "\n".join(
    [
        "",
        "#title { text-align: center; margin-bottom: 0.25rem; }",
        "#subtitle { text-align: center; color: #6b7280; margin-bottom: 1.5rem; font-size: 0.95rem; }",
        "#gen-btn { background: #6366f1 !important; color: white !important; font-weight: 600; }",
        "#gen-btn:hover { background: #4f46e5 !important; }",
        "",
    ]
)
# Canned prompts for gr.Examples; each row holds the value for its single input
# component (the description textbox).
examples = [
    [prompt]
    for prompt in (
        "E-commerce app with Users, Products, Orders, Order Items, and Categories",
        "Blog platform with Authors, Posts, Comments, Tags and a many-to-many between Posts and Tags",
        "Hospital system with Patients, Doctors, Appointments, Prescriptions and Departments",
        "University database with Students, Courses, Professors, Enrollments and Departments",
        "Banking system with Customers, Accounts, Transactions, Loans and Branches",
    )
]
with gr.Blocks(title="AI ER Diagram Generator") as demo:
    gr.Markdown("# 🧠 AI ER Diagram Generator", elem_id="title")
    # The arrow was previously mis-encoded ("β†’").
    gr.Markdown(
        "Describe your database schema in plain English → rendered ER diagram, "
        "powered by **Gemma 4** (ZeroGPU) + **mermaid-py**",
        elem_id="subtitle",
    )
    with gr.Row():
        # Left column: description input, generate button, canned examples.
        with gr.Column(scale=1):
            description = gr.Textbox(
                label="Schema Description",
                placeholder="e.g. E-commerce app with Users, Products, Orders and Categories...",
                lines=6,
            )
            # NOTE(review): "⚑" may itself be a mis-encoded "⚡" — confirm.
            generate_btn = gr.Button(
                "⚑ Generate Diagram", elem_id="gen-btn", variant="primary"
            )
            gr.Examples(examples=examples, inputs=description, label="Try an example")
        # Right column (twice as wide): diagram, download, status, Mermaid source.
        with gr.Column(scale=2):
            diagram_output = gr.HTML(label="ER Diagram")
            # Hidden until generate_er returns a file path for it.
            download_btn = gr.File(
                label="Download Diagram (SVG)", visible=False
            )
            status_output = gr.Markdown(value="No diagram generated yet.")
            # The em dash was previously mis-encoded ("β€”").
            mermaid_code_output = gr.Code(
                label="Mermaid Code — paste into mermaid.live to edit",
                language="markdown",
                lines=14,
            )
    # Output order must match the 4-tuple returned by generate_er.
    generate_btn.click(
        fn=generate_er,
        inputs=description,
        outputs=[diagram_output, mermaid_code_output, status_output, download_btn],
    )
if __name__ == "__main__":
    # The custom stylesheet is supplied at launch time.
    launch_options = {"css": css}
    demo.launch(**launch_options)