Spaces:
Running on Zero
Running on Zero
import html
import os
import re
import tempfile

# NOTE(review): `spaces` is kept first among third-party imports, as in the
# original ordering; ZeroGPU presumably wants it loaded before torch touches CUDA.
import spaces
import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import mermaid as md
from mermaid.graph import Graph
# Populate os.environ from a local .env file before anything is fetched from
# the Hub (presumably for an access token — confirm which variables are used).
load_dotenv()

MODEL_ID = "google/gemma-4-12B-it"

# Tokenizer and model are loaded once, at import time; every request handled
# by this module reuses these two globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # NOTE(review): relies on ZeroGPU patching to move weights CPU→GPU when a
    # GPU slot is attached; confirm device_map="auto" behaves as intended there.
    device_map="auto",
    dtype=torch.bfloat16,
)
# System-turn instructions sent with every generation request. The trailing
# example pins down the exact erDiagram shape the model is expected to emit.
# (The dash in the "No markdown fences" rule was previously mis-encoded as "β".)
SYSTEM_PROMPT = """You are an expert database architect. When given a description, output ONLY valid Mermaid.js erDiagram syntax.
Rules:
- Start with exactly: erDiagram
- Use UPPER_CASE for entity names
- Use proper Mermaid ER relationships: ||--||, ||--o{, }o--o{, etc.
- Always include field definitions with types (int, string, date, float, bool)
- Mark primary keys with PK, foreign keys with FK
- No markdown fences, no explanation, no comments — raw Mermaid code only
Relationship types:
||--|| exactly one to exactly one
||--o{ exactly one to zero or many
}o--o{ zero or many to zero or many
||--|{ exactly one to one or many
Example output:
erDiagram
USER {
int id PK
string name
string email
date created_at
}
ORDER {
int id PK
int user_id FK
float total
date ordered_at
}
USER ||--o{ ORDER : places
"""
@spaces.GPU
def _run_inference(messages: list) -> str:
    """Generate a completion for a chat conversation with the module-level model.

    Decorated with ``spaces.GPU`` so that on a ZeroGPU Space a GPU is attached
    for the duration of this call. The previously imported-but-unused
    ``spaces`` module meant no GPU slot was ever requested.

    Args:
        messages: Chat turns as ``{"role": ..., "content": ...}`` dicts,
            rendered through the tokenizer's chat template.

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens removed.
    """
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # NOTE(review): extra kwarg forwarded to the chat template; confirm
        # this model's template actually reads `thinking`.
        thinking=False,
    )
    # The rendered template already carries the model's special tokens
    # (e.g. BOS). Letting the tokenizer add them again would duplicate them,
    # so they are disabled when re-tokenizing the rendered string.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            do_sample=True,
        )
    # generate() returns prompt + completion; keep only the completion.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
def mermaid_to_svg(mermaid_code: str):
    """Render Mermaid source to an SVG string with mermaid-py.

    The SVG is read from the ``svg_response.text`` attribute of the
    ``md.Mermaid`` object built around a ``Graph`` titled "er-diagram".

    Raises:
        ValueError: If the returned text is empty or contains no ``<svg`` tag.
            The first 200 characters of the text are included in the message.
    """
    rendered = md.Mermaid(Graph("er-diagram", mermaid_code))
    svg_text = rendered.svg_response.text
    looks_like_svg = bool(svg_text) and "<svg" in svg_text
    if looks_like_svg:
        return svg_text
    raise ValueError(f"Empty or invalid SVG returned: {svg_text[:200]}")
# Matches a whole <script ...>...</script> element, case-insensitively and
# across newlines (non-greedy, so consecutive scripts are removed separately).
_SCRIPT_ELEMENT = re.compile(r"<script[\s\S]*?</script>", re.IGNORECASE)


def clean_svg(svg: str) -> str:
    """Remove ``<script>`` elements from an SVG string and trim surrounding whitespace.

    No other markup is touched; no custom styling is added — the Mermaid
    output is otherwise passed through unchanged.
    """
    without_scripts = _SCRIPT_ELEMENT.sub("", svg)
    return without_scripts.strip()
def clean_mermaid(raw: str) -> str:
    """Extract Mermaid erDiagram source from raw model output.

    Lines that begin (after whitespace) with a markdown fence are dropped.
    If the remaining text does not already start with ``erDiagram``, it is cut
    so that it begins at the first line starting with ``erDiagram``. When no
    such line exists, the fence-stripped text is returned unchanged.
    """
    kept_lines = [
        text
        for text in raw.strip().splitlines()
        if not text.strip().startswith("```")
    ]
    code = "\n".join(kept_lines).strip()
    if code.startswith("erDiagram"):
        return code

    start = next(
        (
            idx
            for idx, text in enumerate(kept_lines)
            if text.strip().startswith("erDiagram")
        ),
        None,
    )
    if start is None:
        return code
    return "\n".join(kept_lines[start:]).strip()
def _error_html(title: str, detail: str, suffix_html: str = "") -> str:
    """Build the red error banner shown in the diagram pane.

    ``detail`` is HTML-escaped because it is exception text that may carry
    model output or a slice of the renderer's response body (mermaid_to_svg's
    ValueError embeds up to 200 characters of it). ``suffix_html`` is trusted,
    hard-coded markup appended verbatim.
    """
    return (
        "<div style='color:red; padding:1rem;'>"
        f"<b>{title}:</b> {html.escape(detail)}{suffix_html}</div>"
    )


def generate_er(description):
    """Gradio handler: turn a plain-English schema description into an ER diagram.

    Args:
        description: Free-text schema description from the input textbox.
            Empty, whitespace-only or ``None`` input short-circuits with a hint.

    Returns:
        A 4-tuple ``(diagram_html, mermaid_code, status_text, download_update)``
        in the order of the ``outputs`` list wired to the generate button.
        Errors are never raised; they are reported through the same outputs.
    """
    if not description or not description.strip():
        return (
            "<p style='color:#888; padding:2rem; text-align:center;'>"
            "Enter a schema description to generate a diagram.</p>",
            "",
            "No diagram generated yet.",
            gr.update(value=None, visible=False),
        )
    try:
        # 1. Generate Mermaid code via the LLM (GPU-bound on ZeroGPU).
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": description},
        ]
        mermaid_code = clean_mermaid(_run_inference(messages))

        # 2. Render via mermaid-py. On failure, still return the generated
        #    code so the user can debug it elsewhere.
        try:
            svg = clean_svg(mermaid_to_svg(mermaid_code))
        except Exception as e:
            return (
                _error_html(
                    "Render error",
                    str(e),
                    ". Paste the code below into "
                    "<a href='https://mermaid.live' target='_blank'>mermaid.live</a>"
                    " to debug.",
                ),
                mermaid_code,
                f"Render failed: {e}",
                gr.update(value=None, visible=False),
            )

        # 3. Save the SVG for download. A fresh directory per call lets the
        #    file keep the fixed name "er_diagram.svg".
        #    NOTE(review): nothing here removes these temp directories.
        svg_path = os.path.join(tempfile.mkdtemp(), "er_diagram.svg")
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg)

        # 4. Return the raw SVG directly (no wrapper div / no custom styling).
        return (
            svg,
            mermaid_code,
            "Rendered via mermaid-py",
            gr.update(value=svg_path, visible=True),
        )
    except Exception as e:
        # UI boundary: report any other failure (inference, file I/O) in-page
        # instead of letting the handler crash.
        return (
            _error_html("Unexpected error", str(e)),
            "",
            f"Error: {e}",
            gr.update(value=None, visible=False),
        )
# ── CSS ───────────────────────────────────────────────────────────────────────
# Styles keyed on the elem_id values of the title/subtitle Markdown and the
# Generate button.
css = (
    "\n"
    "#title { text-align: center; margin-bottom: 0.25rem; }\n"
    "#subtitle { text-align: center; color: #6b7280; margin-bottom: 1.5rem; font-size: 0.95rem; }\n"
    "#gen-btn { background: #6366f1 !important; color: white !important; font-weight: 600; }\n"
    "#gen-btn:hover { background: #4f46e5 !important; }\n"
)
# Sample prompts for gr.Examples; each row holds a single value because the
# examples feed exactly one input component (the description textbox).
examples = [
    [prompt]
    for prompt in (
        "E-commerce app with Users, Products, Orders, Order Items, and Categories",
        "Blog platform with Authors, Posts, Comments, Tags and a many-to-many between Posts and Tags",
        "Hospital system with Patients, Doctors, Appointments, Prescriptions and Departments",
        "University database with Students, Courses, Professors, Enrollments and Departments",
        "Banking system with Customers, Accounts, Transactions, Loans and Branches",
    )
]
# Gradio UI: input column (description, button, examples) on the left, and
# output column (rendered diagram, SVG download, status, Mermaid source) on the
# right. User-facing strings previously contained mis-encoded characters
# ("π§", "β", "β‘"); they are restored to 🧠, →, — and ⚡.
with gr.Blocks(title="AI ER Diagram Generator") as demo:
    gr.Markdown("# 🧠 AI ER Diagram Generator", elem_id="title")
    gr.Markdown(
        "Describe your database schema in plain English → rendered ER diagram, "
        "powered by **Gemma 4** (ZeroGPU) + **mermaid-py**",
        elem_id="subtitle",
    )
    with gr.Row():
        with gr.Column(scale=1):
            description = gr.Textbox(
                label="Schema Description",
                placeholder="e.g. E-commerce app with Users, Products, Orders and Categories...",
                lines=6,
            )
            generate_btn = gr.Button(
                "⚡ Generate Diagram", elem_id="gen-btn", variant="primary"
            )
            gr.Examples(examples=examples, inputs=description, label="Try an example")
        with gr.Column(scale=2):
            diagram_output = gr.HTML(label="ER Diagram")
            # Hidden until generate_er returns a saved SVG path.
            download_btn = gr.File(label="Download Diagram (SVG)", visible=False)
            status_output = gr.Markdown(value="No diagram generated yet.")
            mermaid_code_output = gr.Code(
                label="Mermaid Code — paste into mermaid.live to edit",
                language="markdown",
                lines=14,
            )
    # Output order must match the 4-tuple returned by generate_er.
    generate_btn.click(
        fn=generate_er,
        inputs=description,
        outputs=[diagram_output, mermaid_code_output, status_output, download_btn],
    )

if __name__ == "__main__":
    # NOTE(review): the custom CSS is passed to launch() rather than
    # gr.Blocks(); confirm the pinned Gradio version accepts `css` here.
    demo.launch(css=css)