TraceScene-MVP / app.py
SiddharthVenba's picture
fix: animation replay + chat uses vision model
2d2f421 verified
"""
TraceScene β€” Gradio ZeroGPU Application
Serves the custom TraceScene frontend + REST API with GPU-accelerated inference.
Architecture:
- Gradio demo at / (primary β€” required for ZeroGPU)
- Custom FastAPI routes added to Gradio's internal app for REST API
- Custom HTML/CSS/JS frontend served alongside
- @spaces.GPU wraps inference for dynamic GPU allocation
"""
import os
from pathlib import Path
import torch
import gradio as gr
import spaces
# ── Backend Imports ────────────────────────────────────────────────────
from backend.app.config import settings
from backend.app.db.database import db
from backend.app.core.inference import inference_engine, SCENE_ANALYSIS_PROMPT
from backend.app.core.scene_analyzer import SceneAnalyzer
from backend.app.core.rule_matcher import RuleMatcher
from backend.app.core.fault_deducer import FaultDeducer
from backend.app.core.report_generator import ReportGenerator
from backend.app.rules.rule_loader import rule_loader
from backend.app.utils.logger import get_logger
logger = get_logger("app")
scene_analyzer = SceneAnalyzer()
rule_matcher = RuleMatcher()
fault_deducer = FaultDeducer()
report_generator = ReportGenerator()
# ── ZeroGPU: Top-level decorated function ──────────────────────────────
# This MUST be a top-level function wired to a Gradio event handler.
_original_run_inference = inference_engine._run_inference # bound method
@spaces.GPU(duration=120)
def gpu_run_inference(image, prompt):
"""GPU-accelerated inference β€” ZeroGPU allocates GPU for this call."""
return _original_run_inference(image, prompt)
# Monkey-patch so the entire pipeline uses GPU
inference_engine._run_inference = gpu_run_inference
# ── Async helpers ──────────────────────────────────────────────────────
def run_async(coro):
"""Run async coroutine from sync Gradio context."""
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, coro).result()
return loop.run_until_complete(coro)
except RuntimeError:
return asyncio.run(coro)
# ── Initialize backend ────────────────────────────────────────────────
_initialized = False
async def _ensure_init():
global _initialized
if _initialized:
return
await db.connect()
rule_loader.load_rules()
try:
inference_engine.load_model()
except Exception as e:
logger.error(f"Vision model load failed: {e}")
_initialized = True
def ensure_init():
run_async(_ensure_init())
# ── Gradio Handlers ───────────────────────────────────────────────────
def gradio_analyze_photo(image):
"""Analyze a single uploaded photo via GPU."""
if image is None:
return "Please upload an image."
from PIL import Image as PILImage
if not isinstance(image, PILImage.Image):
image = PILImage.fromarray(image)
ensure_init()
if not inference_engine.is_loaded:
inference_engine.load_model()
result = gpu_run_inference(image, SCENE_ANALYSIS_PROMPT)
return result
import json
import hashlib
import time
from PIL import Image
def create_case_fn(case_number, officer_name, location, incident_date, notes):
"""Create a new accident case."""
if not case_number or not case_number.strip():
return "❌ Case number is required.", list_cases_fn()
ensure_init()
try:
cid = run_async(db.create_case(
case_number=case_number.strip(),
officer_name=officer_name.strip() if officer_name else None,
location=location.strip() if location else None,
incident_date=incident_date if incident_date else None,
notes=notes.strip() if notes else None,
))
return f"βœ… Case **{case_number}** created (ID: {cid})", list_cases_fn()
except Exception as e:
return f"❌ {e}", list_cases_fn()
def list_cases_fn():
"""List all cases."""
ensure_init()
try:
cases = run_async(db.list_cases())
if not cases:
return []
rows = []
for c in cases:
photos = run_async(db.get_photos_by_case(c["id"]))
rows.append([
c["id"], c["case_number"],
c.get("officer_name", "β€”"), c.get("location", "β€”"),
c.get("incident_date", "β€”"), c["status"], len(photos),
])
return rows
except Exception:
return []
def delete_case_fn(case_id):
"""Delete a case."""
if not case_id:
return "❌ Enter a Case ID.", list_cases_fn()
ensure_init()
try:
run_async(db.delete_case(int(case_id)))
return f"βœ… Case {int(case_id)} deleted.", list_cases_fn()
except Exception as e:
return f"❌ {e}", list_cases_fn()
def upload_photos_fn(case_id, files):
"""Upload photos to a case."""
if not case_id:
return "❌ Enter a Case ID."
if not files:
return "❌ Select photos to upload."
ensure_init()
try:
case = run_async(db.get_case(int(case_id)))
if not case:
return f"❌ Case {int(case_id)} not found."
case_dir = settings.upload_path / f"case_{int(case_id)}"
case_dir.mkdir(parents=True, exist_ok=True)
count = 0
for fp in files:
with open(fp, "rb") as f:
content = f.read()
filename = Path(fp).name
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
if ext not in settings.allowed_extensions_list:
continue
fhash = hashlib.md5(content).hexdigest()[:12]
dest = case_dir / f"{fhash}_{filename}"
with open(dest, "wb") as f:
f.write(content)
w, h = None, None
try:
img = Image.open(dest)
w, h = img.size
except Exception:
pass
run_async(db.add_photo(
case_id=int(case_id), filename=filename,
filepath=str(dest), file_size=len(content),
width=w, height=h,
))
count += 1
return f"βœ… Uploaded {count} photo(s) to Case {int(case_id)}."
except Exception as e:
return f"❌ {e}"
def get_case_photos_fn(case_id):
"""Get photo gallery for a case."""
if not case_id:
return []
ensure_init()
try:
photos = run_async(db.get_photos_by_case(int(case_id)))
return [(p["filepath"], p["filename"]) for p in photos if Path(p["filepath"]).exists()]
except Exception:
return []
def run_analysis_fn(case_id, progress=gr.Progress()):
"""Run the full AI analysis pipeline (GPU-accelerated)."""
if not case_id:
return "❌ Enter a Case ID.", "", ""
ensure_init()
try:
case = run_async(db.get_case(int(case_id)))
if not case:
return "❌ Case not found.", "", ""
photos = run_async(db.get_photos_by_case(int(case_id)))
if not photos:
return "❌ No photos uploaded.", "", ""
except Exception as e:
return f"❌ {e}", "", ""
if not inference_engine.is_loaded:
inference_engine.load_model()
# Step 1: Analyze each photo
analysis_results = []
for i, photo in enumerate(photos):
progress((i + 1) / len(photos) * 0.5, desc=f"Analyzing photo {i+1}/{len(photos)}...")
try:
img = Image.open(photo["filepath"])
start = time.perf_counter()
raw = gpu_run_inference(img, SCENE_ANALYSIS_PROMPT)
elapsed_ms = (time.perf_counter() - start) * 1000
parsed = scene_analyzer._parse_analysis(raw)
run_async(db.add_scene_analysis(
photo_id=photo["id"], raw_analysis=raw,
vehicles_json=json.dumps(parsed.get("vehicles", [])) if parsed.get("vehicles") else None,
road_conditions_json=json.dumps(parsed.get("road_conditions", {})) if parsed.get("road_conditions") else None,
evidence_json=json.dumps(parsed.get("evidence", {})) if parsed.get("evidence") else None,
environmental_json=json.dumps(parsed.get("environmental", {})) if parsed.get("environmental") else None,
positions_json=json.dumps(parsed.get("positions", {})) if parsed.get("positions") else None,
model_id=settings.model_id, inference_time_ms=elapsed_ms,
))
analysis_results.append({"filename": photo["filename"], "analysis": raw, "time_ms": round(elapsed_ms)})
except Exception as e:
analysis_results.append({"filename": photo["filename"], "analysis": f"Error: {e}", "time_ms": 0})
# Identify parties
progress(0.55, desc="Identifying parties...")
all_analyses = run_async(db.get_analyses_by_case(int(case_id)))
parties_data = scene_analyzer._identify_parties(all_analyses)
run_async(db.clear_parties(int(case_id)))
for p in parties_data:
run_async(db.add_party(
case_id=int(case_id), label=p.get("label", "Unknown"),
vehicle_type=p.get("vehicle_type"), vehicle_color=p.get("vehicle_color"),
vehicle_description=p.get("description"),
))
# Step 2: Rule matching
progress(0.65, desc="Matching traffic rules...")
violations = run_async(rule_matcher.match_violations(int(case_id)))
# Step 3: Fault deduction
progress(0.8, desc="Deducing fault...")
fault_result = run_async(fault_deducer.deduce_fault(int(case_id)))
run_async(db.update_case_status(int(case_id), "complete"))
# Format output
total_time = sum(r["time_ms"] for r in analysis_results)
analysis_text = ""
for r in analysis_results:
analysis_text += f"### πŸ“· {r['filename']} ({r['time_ms']}ms)\n```\n{r['analysis']}\n```\n---\n\n"
violations_text = f"Found {len(violations)} violation(s):\n"
for v in violations:
violations_text += f"\nβ€’ **{v.get('rule_title', '?')}** ({v.get('severity', '?')}) β€” {v.get('confidence', 0):.0%}"
violations_text += f"\n\n### Fault: {fault_result.get('primary_fault_party', 'N/A')}"
violations_text += f"\nConfidence: {fault_result.get('overall_confidence', 0):.0%}"
violations_text += f"\n\n{fault_result.get('analysis_summary', '')}"
progress(1.0, desc="Complete!")
return f"βœ… Done! {len(photos)} photos in {total_time/1000:.1f}s", analysis_text, violations_text
def generate_report_fn(case_id):
"""Generate incident report."""
if not case_id:
return "❌ Enter a Case ID."
ensure_init()
try:
report = run_async(report_generator.generate_report(int(case_id)))
except Exception as e:
return f"❌ {e}"
if "error" in report:
return f"❌ {report['error']}"
c = report.get("case", {})
stats = report.get("statistics", {})
fa = report.get("fault_analysis", {})
md = f"""# πŸš” TraceScene Report
> Case: {c.get('case_number', 'β€”')} | Officer: {c.get('officer_name', 'β€”')}
> Location: {c.get('location', 'β€”')} | Date: {c.get('incident_date', 'β€”')}
*{report.get('disclaimer', '')}*
| Metric | Value |
|---|---|
| Photos | {stats.get('analyzed_photos', 0)} |
| Violations | {stats.get('total_violations', 0)} |
| Critical | {stats.get('critical_violations', 0)} |
| Parties | {stats.get('parties_identified', 0)} |
## Scene Summary
{report.get('scene_summary', 'N/A')}
## Violations
"""
for v in report.get("violations", {}).get("list", []):
md += f"- **{v.get('title', '?')}** [{v.get('severity', '?')}] β€” {v.get('party', '?')} ({v.get('confidence', 0):.0%})\n"
md += f"\n## Fault Analysis\n"
if fa.get("determined"):
md += f"**Primary Fault:** {fa.get('primary_fault_party', '?')}\n"
md += f"**Confidence:** {fa.get('overall_confidence', 0):.0%}\n"
md += f"\n{fa.get('probable_cause', '')}\n"
return md
def get_rules_fn():
"""Get traffic rules."""
ensure_init()
data = rule_loader.get_all_rules()
categories = data.get("categories", [])
if not categories:
return "No rules loaded."
md = "# πŸ“œ Traffic Rules\n\n"
for cat in categories:
md += f"## {cat.get('name', '?')} ({cat.get('rule_count', 0)})\n"
md += "| ID | Title | Severity | Weight |\n|---|---|---|---|\n"
for r in cat.get("rules", []):
md += f"| {r.get('id', '')} | {r.get('title', '')} | {r.get('severity', '')} | {r.get('fault_weight', '')} |\n"
md += "\n"
return md
# ── JSON API functions (for custom frontend via @gradio/client) ────────
def health_fn():
"""Return system health as JSON."""
ensure_init()
return json.dumps({
"status": "ok",
"model_loaded": inference_engine.is_loaded,
"model_id": settings.model_id if inference_engine.is_loaded else None,
"device": inference_engine._device if inference_engine.is_loaded else None,
"rules_loaded": len(rule_loader.get_all_rules().get("categories", [])),
})
def list_cases_json():
"""List cases as JSON."""
ensure_init()
cases = run_async(db.list_cases())
for c in cases:
photos = run_async(db.get_photos_by_case(c["id"]))
c["photo_count"] = len(photos)
return json.dumps({"cases": cases})
def get_case_json(case_id):
"""Get full case details as JSON."""
if not case_id:
return json.dumps({"error": "No case ID"})
ensure_init()
case = run_async(db.get_case(int(case_id)))
if not case:
return json.dumps({"error": f"Case {int(case_id)} not found"})
photos = run_async(db.get_photos_by_case(int(case_id)))
analyses = run_async(db.get_analyses_by_case(int(case_id)))
parties = run_async(db.get_parties_by_case(int(case_id)))
violations = run_async(db.get_violations_by_case(int(case_id)))
fault = run_async(db.get_fault_analysis(int(case_id)))
return json.dumps({
"case": case,
"photos": photos,
"analyses": analyses,
"parties": parties,
"violations": violations,
"fault_analysis": fault,
"stats": {
"total_photos": len(photos),
"analyzed_photos": len(analyses),
"violations_found": len(violations),
"parties_identified": len(parties),
},
})
def get_report_json(case_id):
"""Get report as JSON."""
if not case_id:
return json.dumps({"error": "No case ID"})
ensure_init()
report = run_async(report_generator.generate_report(int(case_id)))
return json.dumps(report)
def get_rules_json():
"""Get rules as JSON."""
ensure_init()
return json.dumps(rule_loader.get_all_rules())
# ── Build Gradio App ──────────────────────────────────────────────────
CUSTOM_CSS = """
.gradio-container { max-width: 1200px !important; }
footer { display: none !important; }
"""
with gr.Blocks(
title="TraceScene β€” AI Accident Analysis",
) as demo:
gr.Markdown("""
# πŸš” TraceScene
### AI-Powered Accident Scene Analysis
*GPU-accelerated inference via ZeroGPU (NVIDIA H200)*
---
""")
with gr.Tabs():
# Tab 1: Quick Analyze (single photo)
with gr.TabItem("⚑ Quick Analyze"):
gr.Markdown("Upload a photo for instant GPU-accelerated analysis.")
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Upload Accident Photo", type="pil")
quick_btn = gr.Button("πŸš€ Analyze with GPU", variant="primary")
with gr.Column():
quick_output = gr.Textbox(label="AI Analysis", lines=20)
quick_btn.click(fn=gradio_analyze_photo, inputs=[input_image], outputs=[quick_output], api_name="analyze_photo")
# Tab 2: Cases
with gr.TabItem("πŸ“‹ Cases"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Create Case")
cn = gr.Textbox(label="Case Number *", placeholder="ACC-2026-001")
on = gr.Textbox(label="Officer Name")
loc = gr.Textbox(label="Location")
dt = gr.Textbox(label="Incident Date", placeholder="YYYY-MM-DD")
nt = gr.Textbox(label="Notes", lines=2)
create_btn = gr.Button("Create Case", variant="primary")
create_status = gr.Markdown()
with gr.Column(scale=2):
gr.Markdown("### Existing Cases")
cases_tbl = gr.Dataframe(
headers=["ID", "Case #", "Officer", "Location", "Date", "Status", "Photos"],
interactive=False,
)
with gr.Row():
refresh_btn = gr.Button("πŸ”„ Refresh")
del_id = gr.Number(label="Case ID to Delete", precision=0)
del_btn = gr.Button("πŸ—‘οΈ Delete", variant="stop")
del_status = gr.Markdown()
create_btn.click(create_case_fn, inputs=[cn, on, loc, dt, nt], outputs=[create_status, cases_tbl], api_name="create_case")
refresh_btn.click(list_cases_fn, outputs=[cases_tbl], api_name="list_cases")
del_btn.click(delete_case_fn, inputs=[del_id], outputs=[del_status, cases_tbl], api_name="delete_case")
# Tab 3: Upload Photos
with gr.TabItem("πŸ“Έ Photos"):
with gr.Row():
with gr.Column(scale=1):
up_case = gr.Number(label="Case ID", precision=0)
up_files = gr.File(label="Select Photos", file_count="multiple", file_types=["image"])
up_btn = gr.Button("Upload", variant="primary")
up_status = gr.Markdown()
with gr.Column(scale=2):
pv_case = gr.Number(label="Preview Case ID", precision=0)
pv_btn = gr.Button("Load Photos")
gallery = gr.Gallery(label="Photos", columns=3)
up_btn.click(upload_photos_fn, inputs=[up_case, up_files], outputs=[up_status], api_name="upload_photos")
pv_btn.click(get_case_photos_fn, inputs=[pv_case], outputs=[gallery], api_name="get_case_photos")
# Tab 4: Run Analysis
with gr.TabItem("🧠 Analysis"):
gr.Markdown("""
### Full Analysis Pipeline (GPU-accelerated)
1. Scene Analysis β†’ 2. Rule Matching β†’ 3. Fault Deduction
""")
an_case = gr.Number(label="Case ID", precision=0)
an_btn = gr.Button("πŸš€ Run Full Analysis", variant="primary", size="lg")
an_status = gr.Markdown()
with gr.Accordion("Scene Details", open=False):
an_detail = gr.Markdown()
an_violations = gr.Markdown(label="Violations & Fault")
an_btn.click(run_analysis_fn, inputs=[an_case], outputs=[an_status, an_detail, an_violations], api_name="run_analysis")
# Tab 5: Report
with gr.TabItem("πŸ“„ Report"):
rp_case = gr.Number(label="Case ID", precision=0)
rp_btn = gr.Button("Generate Report", variant="primary")
rp_out = gr.Markdown()
rp_btn.click(generate_report_fn, inputs=[rp_case], outputs=[rp_out], api_name="generate_report")
# Tab 6: Rules
with gr.TabItem("πŸ“œ Rules"):
ru_btn = gr.Button("Load Traffic Rules")
ru_out = gr.Markdown()
ru_btn.click(get_rules_fn, outputs=[ru_out], api_name="get_rules")
# Tab 7: Chat Q&A
with gr.TabItem("πŸ’¬ Chat"):
gr.Markdown("### Case Q&A Chatbot\nAsk questions about logged cases, traffic rules, or insurance clauses.")
with gr.Row():
chat_case_id = gr.Number(label="Case ID (optional)", precision=0)
chat_load_btn = gr.Button("πŸ“‚ Load Case Context", variant="secondary")
chat_context_status = gr.Markdown(value="*No case loaded. You can still ask general traffic/insurance questions.*")
chatbot = gr.Chatbot(label="Conversation", height=400)
chat_input = gr.Textbox(label="Your Question", placeholder="e.g. What vehicles were involved? What rules were violated?", lines=2)
with gr.Row():
chat_send_btn = gr.Button("πŸ’¬ Send", variant="primary")
chat_clear_btn = gr.Button("πŸ—‘οΈ Clear")
# State for context
chat_system_ctx = gr.State(value="You are TraceScene AI assistant. You help insurers and investigating officers analyze accident cases, traffic rules, and insurance clauses. Answer concisely and accurately based on the context provided.")
def load_chat_context(case_id):
if not case_id:
default_ctx = "You are TraceScene AI assistant. You help insurers and investigating officers analyze accident cases, traffic rules, and insurance clauses. Answer concisely and accurately.\n\n"
# Load traffic rules as general context
ensure_init()
rules_data = rule_loader.get_all_rules()
rules_text = ""
for cat in rules_data.get("categories", []):
rules_text += f"\nCategory: {cat.get('name', '')}\n"
for r in cat.get("rules", []):
rules_text += f" - {r.get('id', '')}: {r.get('title', '')} (Severity: {r.get('severity', '')})\n"
ctx = default_ctx + "TRAFFIC RULES:\n" + rules_text
return ctx, "*General mode: traffic rules loaded. Ask any question!*"
ensure_init()
case = run_async(db.get_case(int(case_id)))
if not case:
return "", f"❌ Case {int(case_id)} not found."
analyses = run_async(db.get_analyses_by_case(int(case_id)))
parties = run_async(db.get_parties_by_case(int(case_id)))
violations = run_async(db.get_violations_by_case(int(case_id)))
fault = run_async(db.get_fault_analysis(int(case_id)))
rules_data = rule_loader.get_all_rules()
ctx = f"""You are TraceScene AI assistant analyzing Case #{case.get('case_number', '')}.
Location: {case.get('location', 'Unknown')}
Date: {case.get('incident_date', 'Unknown')}
Officer: {case.get('officer_name', 'Unknown')}
Status: {case.get('status', 'Unknown')}
SCENE ANALYSES:\n"""
for a in analyses:
ctx += f"\n--- Photo Analysis ---\n{a.get('raw_analysis', '')}\n"
if parties:
ctx += "\nPARTIES IDENTIFIED:\n"
for p in parties:
ctx += f" - {p.get('label', '')}: {p.get('vehicle_type', '')} {p.get('vehicle_color', '')} β€” {p.get('vehicle_description', '')}\n"
if violations:
ctx += "\nVIOLATIONS FOUND:\n"
for v in violations:
ctx += f" - {v.get('rule_title', '')} (Severity: {v.get('severity', '')}, Confidence: {v.get('confidence', 0):.0%})\n"
if fault:
ctx += f"\nFAULT ANALYSIS:\n Primary Fault: {fault.get('primary_fault_party', 'N/A')}\n Confidence: {fault.get('overall_confidence', 0):.0%}\n Summary: {fault.get('analysis_summary', '')}\n"
# Append traffic rules
rules_text = ""
for cat in rules_data.get("categories", []):
rules_text += f"\nCategory: {cat.get('name', '')}\n"
for r in cat.get("rules", []):
rules_text += f" - {r.get('id', '')}: {r.get('title', '')} (Severity: {r.get('severity', '')})\n"
ctx += "\nTRAFFIC RULES:\n" + rules_text
return ctx, f"βœ… Case **{case.get('case_number', '')}** loaded with {len(analyses)} analyses, {len(violations)} violations."
def chat_respond(user_message, history, system_ctx):
if not user_message or not user_message.strip():
return history, "", system_ctx
ensure_init()
if not inference_engine.is_loaded:
inference_engine.load_model()
try:
# Use the vision model's text generation capability for chat
chat_prompt = f"""You are TraceScene AI assistant helping with accident analysis.
CONTEXT:
{system_ctx}
USER QUESTION: {user_message.strip()}
Provide a concise, helpful answer based on the context above."""
# Create a small blank image for the vision model
from PIL import Image as PILImg
blank = PILImg.new('RGB', (64, 64), color=(0, 0, 0))
response = gpu_run_inference(blank, chat_prompt)
except Exception as e:
response = f"Error: {e}"
history = history or []
history.append((user_message.strip(), response))
return history, "", system_ctx
chat_load_btn.click(load_chat_context, inputs=[chat_case_id], outputs=[chat_system_ctx, chat_context_status])
chat_send_btn.click(chat_respond, inputs=[chat_input, chatbot, chat_system_ctx], outputs=[chatbot, chat_input, chat_system_ctx])
chat_input.submit(chat_respond, inputs=[chat_input, chatbot, chat_system_ctx], outputs=[chatbot, chat_input, chat_system_ctx])
chat_clear_btn.click(lambda: ([], ""), outputs=[chatbot, chat_input])
# Tab 8: 2D Animation
with gr.TabItem("🎬 Animation"):
gr.Markdown("### 2D Accident Simulation\nGenerates a bird's-eye view animation from the scene analysis data.")
anim_case_id = gr.Number(label="Case ID", precision=0)
anim_btn = gr.Button("🎬 Generate Animation", variant="primary")
anim_output = gr.HTML(label="Animation")
def generate_animation_fn(case_id):
if not case_id:
return "<p style='color:red;'>Enter a Case ID.</p>"
ensure_init()
analyses = run_async(db.get_analyses_by_case(int(case_id)))
if not analyses:
return "<p style='color:red;'>No analyses found. Run analysis first.</p>"
# Parse scene details from the first analysis
raw = analyses[0].get("raw_analysis", "")
def extract_field(text, field):
import re
pattern = rf"{re.escape(field)}:\s*(.+)"
m = re.search(pattern, text, re.IGNORECASE)
return m.group(1).strip() if m else "Unknown"
road_type = extract_field(raw, "Road Type")
num_vehicles = extract_field(raw, "Vehicles Involved")
v1_pos = extract_field(raw, "Vehicle 1 Position")
v1_tyre = extract_field(raw, "Vehicle 1 Tyre Direction")
impact = extract_field(raw, "Area of Impact")
category = extract_field(raw, "Accident Category")
v1_make = extract_field(raw, "Vehicle 1 Make/Model")
# Check for Vehicle 2
v2_pos = extract_field(raw, "Vehicle 2 Position")
v2_tyre = extract_field(raw, "Vehicle 2 Tyre Direction")
v2_make = extract_field(raw, "Vehicle 2 Make/Model")
has_v2 = v2_make != "Unknown"
# Determine colors from extracted make
import re as re_mod
def extract_color(make_str):
colors = ["Red", "Blue", "White", "Black", "Silver", "Grey", "Green", "Yellow", "Brown", "Orange"]
for c in colors:
if c.lower() in make_str.lower():
return c.lower()
return "#3b82f6"
v1_color = extract_color(v1_make)
v2_color = extract_color(v2_make) if has_v2 else "#ef4444"
# Severity affects animation speed
speed_map = {"mild": 1.5, "medium": 2.5, "critical": 4.0}
anim_speed = speed_map.get(category.lower(), 2.5)
# Road layout
road_is_intersection = "intersection" in road_type.lower()
road_is_highway = "highway" in road_type.lower()
num_v = 1
try:
num_v = int(num_vehicles)
except:
pass
# Unique ID to force Gradio to re-render on each click (enables replay)
import random
uid = random.randint(10000, 99999)
# Determine animation duration based on severity
dur = "3s" if category.lower() == "mild" else "2s" if category.lower() == "medium" else "1.5s"
sev_color = "#22c55e" if category.lower() == "mild" else "#f59e0b" if category.lower() == "medium" else "#ef4444"
# Build SVG road
if road_is_intersection:
road_svg = '''
<rect x="0" y="160" width="700" height="100" fill="#555"/>
<rect x="300" y="0" width="100" height="420" fill="#555"/>
<line x1="0" y1="210" x2="300" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="400" y1="210" x2="700" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="350" y1="0" x2="350" y2="160" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="350" y1="260" x2="350" y2="420" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
'''
else:
road_svg = '''
<rect x="0" y="150" width="700" height="120" fill="#555" rx="2"/>
<line x1="0" y1="210" x2="700" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="0" y1="150" x2="700" y2="150" stroke="white" stroke-width="2"/>
<line x1="0" y1="270" x2="700" y2="270" stroke="white" stroke-width="2"/>
'''
# Vehicle 2 SVG (if present)
v2_svg = ""
if has_v2:
if road_is_intersection:
v2_svg = f'''<g>
<animateTransform attributeName="transform" type="translate" from="0,0" to="0,135" dur="{dur}" fill="freeze"/>
<rect x="325" y="60" width="50" height="26" rx="5" fill="{v2_color}" stroke="#fff" stroke-width="1"/>
<text x="350" y="78" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V2</text>
</g>'''
else:
v2_svg = f'''<g>
<animateTransform attributeName="transform" type="translate" from="0,0" to="-200,0" dur="{dur}" fill="freeze"/>
<rect x="560" y="215" width="50" height="26" rx="5" fill="{v2_color}" stroke="#fff" stroke-width="1"/>
<text x="585" y="233" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V2</text>
</g>'''
html = f'''
<div style="text-align:center; font-family: Inter, Arial, sans-serif;">
<svg id="anim_{uid}" width="700" height="420" viewBox="0 0 700 420" xmlns="http://www.w3.org/2000/svg" style="border:1px solid #444; border-radius:10px; background:#1a1a2e;">
<defs>
<radialGradient id="glow_{uid}" cx="50%" cy="50%" r="50%">
<stop offset="0%" stop-color="#fbbf24" stop-opacity="0.8"/>
<stop offset="100%" stop-color="#fbbf24" stop-opacity="0"/>
</radialGradient>
</defs>
{road_svg}
<!-- Vehicle 1 -->
<g>
<animateTransform attributeName="transform" type="translate" from="0,0" to="200,0" dur="{dur}" fill="freeze"/>
<rect x="80" y="190" width="50" height="26" rx="5" fill="{v1_color}" stroke="#fff" stroke-width="1"/>
<text x="105" y="207" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V1</text>
</g>
{v2_svg}
<!-- Impact flash -->
<circle cx="340" cy="210" r="0" fill="url(#glow_{uid})">
<animate attributeName="r" values="0;0;0;0;0;0;0;45;55;0" dur="{dur}" fill="freeze"/>
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0.5;0" dur="{dur}" fill="freeze"/>
</circle>
<!-- Debris -->
<circle cx="340" cy="210" r="3" fill="#fbbf24" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;310;290" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;185;170" dur="{dur}" fill="freeze"/>
</circle>
<circle cx="340" cy="210" r="2" fill="#ef4444" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;370;395" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;190;175" dur="{dur}" fill="freeze"/>
</circle>
<circle cx="340" cy="210" r="3" fill="#e2e8f0" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;320;305" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;235;255" dur="{dur}" fill="freeze"/>
</circle>
<circle cx="340" cy="210" r="2" fill="#f97316" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;365;385" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;230;250" dur="{dur}" fill="freeze"/>
</circle>
<!-- Collision label -->
<text x="350" y="145" fill="#ef4444" font-size="18" font-weight="bold" text-anchor="middle" opacity="0" font-family="Inter, Arial, sans-serif">
COLLISION
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;1" dur="{dur}" fill="freeze"/>
</text>
<!-- HUD -->
<rect x="10" y="350" width="680" height="60" rx="8" fill="rgba(0,0,0,0.6)"/>
<text x="20" y="375" fill="#e2e8f0" font-size="12" font-family="Inter, Arial, sans-serif">{v1_make[:35]}</text>
<text x="20" y="398" fill="#e2e8f0" font-size="12" font-family="Inter, Arial, sans-serif">{"" if not has_v2 else v2_make[:35]}{"Single vehicle accident" if not has_v2 else ""}</text>
<text x="680" y="375" fill="{sev_color}" font-size="14" font-weight="bold" text-anchor="end" font-family="Inter, Arial, sans-serif">{category.upper()}</text>
<text x="680" y="398" fill="#94a3b8" font-size="11" text-anchor="end" font-family="Inter, Arial, sans-serif">Impact: {impact} | Road: {road_type}</text>
</svg>
<div style="margin-top:8px; color:#94a3b8; font-size:12px;">
Vehicles: {num_v} | Animation auto-plays on load
</div>
</div>
'''
return html
anim_btn.click(generate_animation_fn, inputs=[anim_case_id], outputs=[anim_output])
anim_btn.click(generate_animation_fn, inputs=[anim_case_id], outputs=[anim_output])
# Hidden API-only endpoints (for @gradio/client from custom frontend)
with gr.TabItem("πŸ”Œ API", visible=False):
api_health_btn = gr.Button("health")
api_health_out = gr.Textbox()
api_health_btn.click(health_fn, outputs=[api_health_out], api_name="health")
api_cases_btn = gr.Button("list_cases_json")
api_cases_out = gr.Textbox()
api_cases_btn.click(list_cases_json, outputs=[api_cases_out], api_name="list_cases_json")
api_case_id = gr.Number(precision=0)
api_case_btn = gr.Button("get_case")
api_case_out = gr.Textbox()
api_case_btn.click(get_case_json, inputs=[api_case_id], outputs=[api_case_out], api_name="get_case")
api_report_id = gr.Number(precision=0)
api_report_btn = gr.Button("get_report")
api_report_out = gr.Textbox()
api_report_btn.click(get_report_json, inputs=[api_report_id], outputs=[api_report_out], api_name="get_report_json")
api_rules_btn = gr.Button("get_rules_json")
api_rules_out = gr.Textbox()
api_rules_btn.click(get_rules_json, outputs=[api_rules_out], api_name="get_rules_json")
gr.Markdown("---\n*TraceScene β€” Built by Siddharth Ravikumar | tracescene@zohomail.ae*")
# ── Launch (required for ZeroGPU detection) ────────────────────────────
demo.launch(
server_name="0.0.0.0",
server_port=7860,
css=CUSTOM_CSS,
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="slate",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter"),
),
)