"""Gradio app that turns a plain-English schema description into an ER diagram.

An LLM (run on a ZeroGPU slot) writes Mermaid ``erDiagram`` code, mermaid-py
renders it to SVG, and the SVG is shown inline and offered as a download.
"""

import html
import logging
import os
import re
import tempfile

import spaces
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

import gradio as gr
import mermaid as md
from mermaid.graph import Graph

load_dotenv()

logger = logging.getLogger(__name__)

MODEL_ID = "google/gemma-4-12B-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",  # ZeroGPU patches this; handles CPU↔GPU transparently
)

SYSTEM_PROMPT = """You are an expert database architect. When given a description, output ONLY valid Mermaid.js erDiagram syntax.

Rules:
- Start with exactly: erDiagram
- Use UPPER_CASE for entity names
- Use proper Mermaid ER relationships: ||--||, ||--o{, }o--o{, etc.
- Always include field definitions with types (int, string, date, float, bool)
- Mark primary keys with PK, foreign keys with FK
- No markdown fences, no explanation, no comments — raw Mermaid code only

Relationship types:
  ||--||  exactly one to exactly one
  ||--o{  exactly one to zero or many
  }o--o{  zero or many to zero or many
  ||--|{  exactly one to one or many

Example output:
erDiagram
    USER {
        int id PK
        string name
        string email
        date created_at
    }
    ORDER {
        int id PK
        int user_id FK
        float total
        date ordered_at
    }
    USER ||--o{ ORDER : places
"""


@spaces.GPU
def _run_inference(messages: list) -> str:
    """Generate a completion for *messages* inside a ZeroGPU-allocated slot.

    Args:
        messages: Chat messages (``{"role": ..., "content": ...}`` dicts)
            rendered through the tokenizer's chat template.

    Returns:
        Only the newly generated text (prompt tokens removed), decoded with
        special tokens skipped.
    """
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # Extra kwargs are forwarded to the chat template.
        # NOTE(review): confirm this template actually reads `thinking`.
        thinking=False,
    )
    # The rendered template normally already carries BOS (Gemma templates
    # emit it). Per the HF chat-templating docs, avoid adding a second one
    # when tokenizing the rendered string separately.
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            do_sample=True,
        )
    # generate() echoes the prompt; keep only the tokens produced after it.
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def mermaid_to_svg(mermaid_code: str) -> str:
    """Render Mermaid code to SVG using mermaid-py only.

    Raises:
        ValueError: if the renderer's response body is empty or contains no
            ``<svg`` element (presumably an error page for invalid syntax).
    """
    graph = Graph("er-diagram", mermaid_code)
    diagram = md.Mermaid(graph)
    svg = diagram.svg_response.text
    # NOTE(review): this validity check was reconstructed after the original
    # text was lost; confirm the intended error type/message.
    if not svg or "<svg" not in svg:
        raise ValueError(f"mermaid-py did not return SVG: {svg[:200]!r}")
    return svg


def clean_svg(svg: str) -> str:
    """Strip the XML prolog, DOCTYPE and any <script> elements from *svg*.

    This makes the markup safe to embed inline in a ``gr.HTML`` component.
    Returns the cleaned markup with surrounding whitespace trimmed.
    """
    # NOTE(review): the original patterns were lost in extraction; these are a
    # reconstruction — confirm they match the intended sanitisation.
    svg = re.sub(r"<\?xml[^>]*\?>", "", svg, flags=re.IGNORECASE)
    svg = re.sub(r"<!DOCTYPE[^>]*>", "", svg, flags=re.IGNORECASE)
    svg = re.sub(
        r"<script\b.*?</script\s*>", "", svg, flags=re.IGNORECASE | re.DOTALL
    )
    return svg.strip()


def clean_mermaid(raw: str) -> str:
    """Strip markdown fences and find erDiagram start.

    Every line beginning with a ``` fence is dropped. Any preamble before the
    first line starting with ``erDiagram`` is then discarded. If no such line
    exists, the fence-free text is returned as-is (stripped).
    """
    lines = [
        line for line in raw.strip().splitlines()
        if not line.strip().startswith("```")
    ]
    start = next(
        (i for i, line in enumerate(lines) if line.strip().startswith("erDiagram")),
        0,
    )
    return "\n".join(lines[start:]).strip()


def _notice_html(message_html: str, color: str = "#6b7280") -> str:
    """Wrap an already-escaped HTML fragment in a simple centered notice box."""
    return (
        f"<div style='padding: 2rem; text-align: center; color: {color};'>"
        f"<p>{message_html}</p></div>"
    )


def generate_er(description):
    """Generate, render and package an ER diagram for *description*.

    Returns a 4-tuple matching the click handler's outputs:
    (diagram HTML, Mermaid code, status markdown, update for the download File).
    Failures are reported through these values rather than raised.
    """
    if not (description or "").strip():
        return (
            _notice_html("Enter a schema description to generate a diagram."),
            "",
            "No diagram generated yet.",
            gr.update(value=None, visible=False),
        )

    try:
        # 1. Generate Mermaid code via LLM (ZeroGPU)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": description},
        ]
        raw_output = _run_inference(messages)
        mermaid_code = clean_mermaid(raw_output)

        # 2. Render via mermaid-py (no custom theme)
        try:
            svg = clean_svg(mermaid_to_svg(mermaid_code))
        except Exception as e:
            # Still return the generated code so the user can debug it.
            logger.warning("Mermaid render failed: %s", e)
            return (
                _notice_html(
                    f"Render error: {html.escape(str(e))}. Paste the code below into "
                    "<a href='https://mermaid.live' target='_blank' rel='noopener'>"
                    "mermaid.live</a> to debug.",
                    color="#dc2626",
                ),
                mermaid_code,
                f"Render failed: {e}",
                gr.update(value=None, visible=False),
            )

        # 3. Save SVG to a temp file for download (fresh dir keeps the
        #    user-facing filename stable while avoiding collisions).
        svg_path = os.path.join(tempfile.mkdtemp(), "er_diagram.svg")
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg)

        # 4. Return raw SVG directly (no wrapper div / no custom styling)
        return (
            svg,
            mermaid_code,
            "Rendered via mermaid-py",
            gr.update(value=svg_path, visible=True),
        )
    except Exception as e:
        # UI boundary: report instead of crashing the event, but keep a trace.
        logger.exception("ER diagram generation failed")
        return (
            _notice_html(f"Unexpected error: {html.escape(str(e))}", color="#dc2626"),
            "",
            f"Error: {e}",
            gr.update(value=None, visible=False),
        )


# ── CSS ───────────────────────────────────────────────────────────────────────
css = """
#title { text-align: center; margin-bottom: 0.25rem; }
#subtitle { text-align: center; color: #6b7280; margin-bottom: 1.5rem; font-size: 0.95rem; }
#gen-btn { background: #6366f1 !important; color: white !important; font-weight: 600; }
#gen-btn:hover { background: #4f46e5 !important; }
"""

examples = [
    ["E-commerce app with Users, Products, Orders, Order Items, and Categories"],
    [
        "Blog platform with Authors, Posts, Comments, Tags and a many-to-many between Posts and Tags"
    ],
    [
        "Hospital system with Patients, Doctors, Appointments, Prescriptions and Departments"
    ],
    [
        "University database with Students, Courses, Professors, Enrollments and Departments"
    ],
    ["Banking system with Customers, Accounts, Transactions, Loans and Branches"],
]

with gr.Blocks(title="AI ER Diagram Generator") as demo:
    gr.Markdown("# 🧠 AI ER Diagram Generator", elem_id="title")
    gr.Markdown(
        "Describe your database schema in plain English → rendered ER diagram, "
        "powered by **Gemma 4** (ZeroGPU) + **mermaid-py**",
        elem_id="subtitle",
    )

    with gr.Row():
        with gr.Column(scale=1):
            description = gr.Textbox(
                label="Schema Description",
                placeholder="e.g. E-commerce app with Users, Products, Orders and Categories...",
                lines=6,
            )
            generate_btn = gr.Button(
                "⚡ Generate Diagram", elem_id="gen-btn", variant="primary"
            )
            gr.Examples(examples=examples, inputs=description, label="Try an example")

        with gr.Column(scale=2):
            diagram_output = gr.HTML(label="ER Diagram")
            download_btn = gr.File(
                label="Download Diagram (SVG)", visible=False
            )
            status_output = gr.Markdown(value="No diagram generated yet.")
            mermaid_code_output = gr.Code(
                label="Mermaid Code — paste into mermaid.live to edit",
                language="markdown",
                lines=14,
            )

    generate_btn.click(
        fn=generate_er,
        inputs=description,
        outputs=[diagram_output, mermaid_code_output, status_output, download_btn],
    )

if __name__ == "__main__":
    demo.launch(css=css)