Spaces:
Running on Zero
Running on Zero
Siddharth Ravikumar
fix: refactor image loading to match TraceSceneUI structure using root-relative paths and direct directory mounting
e08a816 | """ | |
| TraceScene — Gradio ZeroGPU Application | |
| Serves the custom TraceScene frontend + REST API with GPU-accelerated inference. | |
| Architecture: | |
| - Gradio demo at / (primary — required for ZeroGPU) | |
| - Custom FastAPI routes added to Gradio's internal app for REST API | |
| - Custom HTML/CSS/JS frontend served alongside | |
| - @spaces.GPU wraps inference for dynamic GPU allocation | |
| """ | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| # ── Backend Imports ──────────────────────────────────────────────────── | |
| from backend.app.config import settings | |
| from backend.app.db.database import db | |
| from backend.app.core.inference import inference_engine, chat_engine, SCENE_ANALYSIS_PROMPT | |
| from backend.app.core.scene_analyzer import SceneAnalyzer | |
| from backend.app.core.rule_matcher import RuleMatcher | |
| from backend.app.core.fault_deducer import FaultDeducer | |
| from backend.app.core.report_generator import ReportGenerator | |
| from backend.app.rules.rule_loader import rule_loader | |
| from backend.app.utils.logger import get_logger | |
| from backend.app.api.routes import router | |
# Module-wide logger for this entry point.
logger = get_logger("app")

# One shared instance of each pipeline stage, reused by every handler below.
scene_analyzer = SceneAnalyzer()
rule_matcher = RuleMatcher()
fault_deducer = FaultDeducer()
report_generator = ReportGenerator()
| from backend.app.core.reference_data import REFERENCE_CASES | |
# ── ZeroGPU: top-level decorated functions ──────────────────────────────
# These MUST be top-level functions wired to a Gradio event handler.
# They are wrapped with @spaces.GPU so that ZeroGPU attaches a GPU for the
# duration of each call, then patched back onto the engine objects so the
# rest of the pipeline goes through the decorated entry points.
_original_run_inference = inference_engine._run_inference  # bound method


@spaces.GPU
def gpu_run_inference(image, prompt):
    """Run vision inference on *image* with *prompt* inside a ZeroGPU call.

    Delegates to the engine's original ``_run_inference`` bound method and
    returns whatever it returns.
    """
    return _original_run_inference(image, prompt)


# Monkey-patch so the entire pipeline uses GPU
inference_engine._run_inference = gpu_run_inference

_original_chat = chat_engine.chat


@spaces.GPU
def gpu_run_chat(system_context: str, user_message: str):
    """Run chat inference inside a ZeroGPU call.

    Delegates to the chat engine's original ``chat`` method and returns
    whatever it returns.
    """
    return _original_chat(system_context, user_message)


chat_engine.chat = gpu_run_chat
# ── Async helpers ───────────────────────────────────────────────────────
def run_async(coro):
    """Run an async coroutine to completion from synchronous code.

    Args:
        coro: The coroutine object to execute. It is awaited exactly once.

    Returns:
        The coroutine's result.

    Raises:
        Whatever the coroutine raises, unchanged.
    """
    import asyncio
    import concurrent.futures

    # Only the loop lookup is guarded: a RuntimeError raised by the coroutine
    # itself must propagate instead of triggering a second run of an
    # already-awaited coroutine.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so we can own one for the call.
        return asyncio.run(coro)

    # A loop is already running here and cannot be re-entered; execute the
    # coroutine on a fresh loop in a helper thread and block for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
# ── Initialize backend ──────────────────────────────────────────────────
_initialized = False


async def _ensure_init():
    """Connect the database, load rules and try to load the vision model.

    Runs its body only while the module-level ``_initialized`` flag is
    false. A model-load failure is logged and does not prevent the flag
    from being set.
    """
    global _initialized
    if not _initialized:
        await db.connect()
        rule_loader.load_rules()
        try:
            inference_engine.load_model()
        except Exception as exc:
            logger.error(f"Vision model load failed: {exc}")
        _initialized = True


def ensure_init():
    """Synchronous wrapper that drives ``_ensure_init`` via ``run_async``."""
    run_async(_ensure_init())
# ── Gradio Handlers ─────────────────────────────────────────────────────
def gradio_analyze_photo(image):
    """Analyze one uploaded photo and return the raw model output.

    Returns a prompt asking for an upload when *image* is ``None``. Inputs
    that are not PIL images are converted with ``PIL.Image.fromarray``.
    """
    if image is None:
        return "Please upload an image."

    from PIL import Image as PILImage

    picture = image if isinstance(image, PILImage.Image) else PILImage.fromarray(image)

    ensure_init()
    # Initialization may have failed to load the model; retry before inference.
    if not inference_engine.is_loaded:
        inference_engine.load_model()
    return gpu_run_inference(picture, SCENE_ANALYSIS_PROMPT)
| import json | |
| import hashlib | |
| import time | |
| from PIL import Image | |
def create_case_fn(case_number, officer_name, location, incident_date, notes):
    """Create a new accident case.

    Returns a ``(status_message, case_rows)`` pair; the rows come from
    ``list_cases_fn`` so the caller can refresh its table in the same step.
    """
    if not (case_number and case_number.strip()):
        return "β Case number is required.", list_cases_fn()

    def _clean(text):
        # Empty / missing optional fields are stored as None.
        return text.strip() if text else None

    ensure_init()
    try:
        new_id = run_async(db.create_case(
            case_number=case_number.strip(),
            officer_name=_clean(officer_name),
            location=_clean(location),
            incident_date=incident_date or None,
            notes=_clean(notes),
        ))
        return f"β Case **{case_number}** created (ID: {new_id})", list_cases_fn()
    except Exception as exc:
        return f"β {exc}", list_cases_fn()
def list_cases_fn():
    """List all cases as table rows.

    Returns:
        A list of ``[id, case_number, officer, location, date, status,
        photo_count]`` rows, or an empty list when there are no cases or the
        lookup fails (the failure is logged rather than raised so the UI
        table can still render).
    """
    ensure_init()
    try:
        cases = run_async(db.list_cases())
        if not cases:
            return []
        rows = []
        for c in cases:
            photos = run_async(db.get_photos_by_case(c["id"]))
            rows.append([
                c["id"], c["case_number"],
                c.get("officer_name", "β"), c.get("location", "β"),
                c.get("incident_date", "β"), c["status"], len(photos),
            ])
        return rows
    except Exception as e:
        # Best-effort listing: keep returning an empty table, but leave a
        # trace so failures are not invisible.
        logger.error(f"Listing cases failed: {e}")
        return []
def delete_case_fn(case_id):
    """Delete the case with the given ID.

    Returns a ``(status_message, case_rows)`` pair, with the rows taken from
    ``list_cases_fn`` after the attempt.
    """
    if not case_id:
        return "β Enter a Case ID.", list_cases_fn()

    ensure_init()
    try:
        numeric_id = int(case_id)
        run_async(db.delete_case(numeric_id))
        return f"β Case {numeric_id} deleted.", list_cases_fn()
    except Exception as exc:
        return f"β {exc}", list_cases_fn()
def upload_photos_fn(case_id, files):
    """Copy uploaded photos into the case's upload directory and register them.

    Args:
        case_id: ID of the target case.
        files: Paths of the uploaded files.

    Returns:
        A status message. Files whose extension is not in
        ``settings.allowed_extensions_list`` are skipped and not counted.
    """
    if not case_id:
        return "β Enter a Case ID."
    if not files:
        return "β Select photos to upload."
    ensure_init()
    try:
        cid = int(case_id)
        case = run_async(db.get_case(cid))
        if not case:
            return f"β Case {cid} not found."
        case_dir = settings.upload_path / f"case_{cid}"
        case_dir.mkdir(parents=True, exist_ok=True)
        count = 0
        for fp in files:
            filename = Path(fp).name
            ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
            # Reject by extension before touching the file contents.
            if ext not in settings.allowed_extensions_list:
                continue
            with open(fp, "rb") as f:
                content = f.read()
            # A short content hash prefix keeps same-named uploads apart.
            fhash = hashlib.md5(content).hexdigest()[:12]
            dest = case_dir / f"{fhash}_{filename}"
            with open(dest, "wb") as f:
                f.write(content)
            w, h = None, None
            try:
                # Context manager releases the file handle PIL keeps open.
                with Image.open(dest) as img:
                    w, h = img.size
            except Exception:
                # Dimensions are optional metadata; store the photo anyway.
                pass
            run_async(db.add_photo(
                case_id=cid, filename=filename,
                filepath=str(dest), file_size=len(content),
                width=w, height=h,
            ))
            count += 1
        return f"β Uploaded {count} photo(s) to Case {cid}."
    except Exception as e:
        return f"β {e}"
def get_case_photos_fn(case_id):
    """Return ``(filepath, filename)`` pairs for a case's photo gallery.

    Stored photos are returned only when their file exists on disk. When
    the case has no stored photos, the photos of a matching entry in
    ``REFERENCE_CASES`` are returned instead. Any failure is logged and
    yields an empty gallery.
    """
    if not case_id:
        return []
    ensure_init()
    try:
        cid = int(case_id)
        photos = run_async(db.get_photos_by_case(cid))
        if not photos:
            # Check reference cases
            ref = REFERENCE_CASES.get(cid)
            if ref:
                return [(p["filepath"], p["filename"]) for p in ref["photos"]]
            return []
        return [
            (p["filepath"], p["filename"])
            for p in photos
            if Path(p["filepath"]).exists()
        ]
    except Exception as e:
        # Best-effort gallery: report the problem instead of hiding it.
        logger.error(f"Loading photos for case {case_id} failed: {e}")
        return []
def _json_or_none(value):
    """Serialize *value* to JSON, or return None when it is empty or missing."""
    return json.dumps(value) if value else None


def run_analysis_fn(case_id, progress=gr.Progress()):
    """Run the full AI analysis pipeline (GPU-accelerated) for one case.

    Steps: per-photo scene analysis, party identification, rule matching,
    fault deduction, then marking the case ``"complete"``.

    Args:
        case_id: ID of the case to analyze.
        progress: Gradio progress tracker, called with a fraction and a
            description as the pipeline advances.

    Returns:
        A ``(status, scene_details_markdown, violations_markdown)`` triple.
        On an unexpected error the status carries the message and the second
        element carries the traceback text.
    """
    import traceback
    try:
        if not case_id:
            return "β Enter a Case ID.", "", ""
        ensure_init()
        cid = int(case_id)
        case = run_async(db.get_case(cid))
        if not case:
            return "β Case not found.", "", ""
        photos = run_async(db.get_photos_by_case(cid))
        if not photos:
            return "β No photos uploaded.", "", ""
        if not inference_engine.is_loaded:
            inference_engine.load_model()

        # Step 1: Analyze each photo (first half of the progress bar)
        analysis_results = []
        total = len(photos)
        for i, photo in enumerate(photos, start=1):
            progress(i / total * 0.5, desc=f"Analyzing photo {i}/{total}...")
            try:
                # Context manager releases the image file handle afterwards.
                with Image.open(photo["filepath"]) as img:
                    start = time.perf_counter()
                    raw = gpu_run_inference(img, SCENE_ANALYSIS_PROMPT)
                    elapsed_ms = (time.perf_counter() - start) * 1000
                parsed = scene_analyzer._parse_analysis(raw)
                run_async(db.add_scene_analysis(
                    photo_id=photo["id"], raw_analysis=raw,
                    vehicles_json=_json_or_none(parsed.get("vehicles")),
                    road_conditions_json=_json_or_none(parsed.get("road_conditions")),
                    evidence_json=_json_or_none(parsed.get("evidence")),
                    environmental_json=_json_or_none(parsed.get("environmental")),
                    positions_json=_json_or_none(parsed.get("positions")),
                    model_id=settings.model_id, inference_time_ms=elapsed_ms,
                ))
                analysis_results.append({"filename": photo["filename"], "analysis": raw, "time_ms": round(elapsed_ms)})
            except Exception as e:
                # One bad photo must not abort the case: record the error as
                # that photo's analysis and continue with the next one.
                err_msg = f"Error: {e}"
                run_async(db.add_scene_analysis(
                    photo_id=photo["id"],
                    raw_analysis=err_msg,
                    model_id=settings.model_id,
                    inference_time_ms=0,
                ))
                analysis_results.append({"filename": photo["filename"], "analysis": err_msg, "time_ms": 0})

        # Identify parties (replaces any parties stored by an earlier run)
        progress(0.55, desc="Identifying parties...")
        all_analyses = run_async(db.get_analyses_by_case(cid))
        parties_data = scene_analyzer._identify_parties(all_analyses)
        run_async(db.clear_parties(cid))
        for p in parties_data:
            run_async(db.add_party(
                case_id=cid, label=p.get("label", "Unknown"),
                vehicle_type=p.get("vehicle_type"), vehicle_color=p.get("vehicle_color"),
                vehicle_description=p.get("description"),
            ))

        # Step 2: Rule matching
        progress(0.65, desc="Matching traffic rules...")
        violations = run_async(rule_matcher.match_violations(cid))

        # Step 3: Fault deduction
        progress(0.8, desc="Deducing fault...")
        fault_result = run_async(fault_deducer.deduce_fault(cid))
        run_async(db.update_case_status(cid, "complete"))

        # Format output
        total_time = sum(r["time_ms"] for r in analysis_results)
        analysis_text = "".join(
            f"### π· {r['filename']} ({r['time_ms']}ms)\n```\n{r['analysis']}\n```\n---\n\n"
            for r in analysis_results
        )
        violation_parts = [f"Found {len(violations)} violation(s):\n"]
        violation_parts.extend(
            f"\nβ’ **{v.get('rule_title', '?')}** ({v.get('severity', '?')}) β {v.get('confidence', 0):.0%}"
            for v in violations
        )
        violation_parts.append(f"\n\n### Fault: {fault_result.get('primary_fault_party', 'N/A')}")
        violation_parts.append(f"\nConfidence: {fault_result.get('overall_confidence', 0):.0%}")
        violation_parts.append(f"\n\n{fault_result.get('analysis_summary', '')}")
        violations_text = "".join(violation_parts)

        progress(1.0, desc="Complete!")
        return f"β Done! {len(photos)} photos in {total_time/1000:.1f}s", analysis_text, violations_text
    except Exception as e:
        return f"β Python Error: {e}", traceback.format_exc(), ""
def generate_report_fn(case_id):
    """Render the incident report for a case as Markdown.

    Returns an error message instead when no ID is given, when report
    generation raises, or when the report dict contains an ``"error"`` key.
    """
    if not case_id:
        return "β Enter a Case ID."

    ensure_init()
    try:
        report = run_async(report_generator.generate_report(int(case_id)))
    except Exception as exc:
        return f"β {exc}"
    if "error" in report:
        return f"β {report['error']}"

    case_info = report.get("case", {})
    stats = report.get("statistics", {})
    fault = report.get("fault_analysis", {})

    # Header, metrics table and scene summary come first; the violation
    # bullets and fault section are appended after it.
    sections = [f"""# π TraceScene Report
> Case: {case_info.get('case_number', 'β')} | Officer: {case_info.get('officer_name', 'β')}
> Location: {case_info.get('location', 'β')} | Date: {case_info.get('incident_date', 'β')}
*{report.get('disclaimer', '')}*
| Metric | Value |
|---|---|
| Photos | {stats.get('analyzed_photos', 0)} |
| Violations | {stats.get('total_violations', 0)} |
| Critical | {stats.get('critical_violations', 0)} |
| Parties | {stats.get('parties_identified', 0)} |
## Scene Summary
{report.get('scene_summary', 'N/A')}
## Violations
"""]
    sections.extend(
        f"- **{item.get('title', '?')}** [{item.get('severity', '?')}] β {item.get('party', '?')} ({item.get('confidence', 0):.0%})\n"
        for item in report.get("violations", {}).get("list", [])
    )
    sections.append("\n## Fault Analysis\n")
    if fault.get("determined"):
        sections.append(f"**Primary Fault:** {fault.get('primary_fault_party', '?')}\n")
        sections.append(f"**Confidence:** {fault.get('overall_confidence', 0):.0%}\n")
        sections.append(f"\n{fault.get('probable_cause', '')}\n")
    return "".join(sections)
def get_rules_fn():
    """Render the loaded traffic rules as Markdown, one table per category."""
    ensure_init()
    categories = rule_loader.get_all_rules().get("categories", [])
    if not categories:
        return "No rules loaded."

    chunks = ["# π Traffic Rules\n\n"]
    for category in categories:
        chunks.append(f"## {category.get('name', '?')} ({category.get('rule_count', 0)})\n")
        chunks.append("| ID | Title | Severity | Weight |\n|---|---|---|---|\n")
        chunks.extend(
            f"| {rule.get('id', '')} | {rule.get('title', '')} | {rule.get('severity', '')} | {rule.get('fault_weight', '')} |\n"
            for rule in category.get("rules", [])
        )
        chunks.append("\n")
    return "".join(chunks)
# ── JSON API functions (for custom frontend via @gradio/client) ─────────
def health_fn():
    """Return system health as a JSON string.

    ``model_id`` and ``device`` are ``null`` while the vision model is not
    loaded. ``rules_loaded`` is the total number of rules across all
    categories.
    """
    ensure_init()
    model_loaded = inference_engine.is_loaded
    # get_all_rules() is read elsewhere in this file as
    # {"categories": [{"rules": [...]}, ...]}; len() of that dict would count
    # its top-level keys, not rules, so sum the per-category rule lists.
    categories = rule_loader.get_all_rules().get("categories", [])
    rules_loaded = sum(len(cat.get("rules", [])) for cat in categories)
    return json.dumps({
        "status": "ok",
        "model_loaded": model_loaded,
        "model_id": settings.model_id if model_loaded else None,
        "device": inference_engine._device if model_loaded else None,
        "rules_loaded": rules_loaded,
    })
def list_cases_json():
    """Return all cases as a JSON string, reference cases first.

    Each stored case is annotated with ``photo_count`` and
    ``is_reference = False`` before serialization.
    """
    ensure_init()
    stored_cases = run_async(db.list_cases())
    for entry in stored_cases:
        entry["photo_count"] = len(run_async(db.get_photos_by_case(entry["id"])))
        entry["is_reference"] = False

    # Built-in reference cases are listed ahead of the stored ones.
    reference_cases = [ref["case"] for ref in REFERENCE_CASES.values()]
    return json.dumps({"cases": reference_cases + stored_cases})
def get_case_json(case_id):
    """Return full case details as a JSON string.

    Reference cases are served from ``REFERENCE_CASES`` without touching
    the database; other IDs are looked up in the database. Missing IDs and
    unknown cases produce an ``{"error": ...}`` payload.
    """
    if not case_id:
        return json.dumps({"error": "No case ID"})

    def _stats(photos, analyses, violations, parties):
        # Summary counters shared by the reference and database branches.
        return {
            "total_photos": len(photos),
            "analyzed_photos": len(analyses),
            "violations_found": len(violations),
            "parties_identified": len(parties),
        }

    cid = int(case_id)

    # Check reference cases first
    reference = REFERENCE_CASES.get(cid)
    if reference:
        payload = reference.copy()
        payload["stats"] = _stats(
            payload["photos"], payload["analyses"],
            payload["violations"], payload["parties"],
        )
        return json.dumps(payload)

    ensure_init()
    case = run_async(db.get_case(cid))
    if not case:
        return json.dumps({"error": f"Case {cid} not found"})

    photos = run_async(db.get_photos_by_case(cid))
    analyses = run_async(db.get_analyses_by_case(cid))
    parties = run_async(db.get_parties_by_case(cid))
    violations = run_async(db.get_violations_by_case(cid))
    fault = run_async(db.get_fault_analysis(cid))

    case_dict = dict(case)
    case_dict["is_reference"] = False
    return json.dumps({
        "case": case_dict,
        "photos": photos,
        "analyses": analyses,
        "parties": parties,
        "violations": violations,
        "fault_analysis": fault,
        "stats": _stats(photos, analyses, violations, parties),
    })
def get_report_json(case_id):
    """Return the generated report for a case as a JSON string."""
    if case_id:
        ensure_init()
        report = run_async(report_generator.generate_report(int(case_id)))
        return json.dumps(report)
    return json.dumps({"error": "No case ID"})
def get_rules_json():
    """Return everything ``rule_loader.get_all_rules()`` yields, as JSON text."""
    ensure_init()
    all_rules = rule_loader.get_all_rules()
    return json.dumps(all_rules)
def _format_rules_text(rules_data):
    """Flatten the rule categories into the plain-text list used as chat context."""
    lines = []
    for category in rules_data.get("categories", []):
        lines.append(f"\nCategory: {category.get('name', '')}\n")
        lines.extend(
            f" - {rule.get('id', '')}: {rule.get('title', '')} (Severity: {rule.get('severity', '')})\n"
            for rule in category.get("rules", [])
        )
    return "".join(lines)


def load_chat_context(case_id):
    """Build the chatbot's system context.

    Without a case ID the context is the generic assistant prompt plus the
    traffic rules. With a case ID it also includes the case header, scene
    analyses, parties, violations and fault analysis read from the database.

    Returns:
        A ``(system_context, status_markdown)`` pair; the context is empty
        when the case does not exist.
    """
    if not case_id:
        intro = "You are TraceScene AI assistant. You help insurers and investigating officers analyze accident cases, traffic rules, and insurance clauses. Answer concisely and accurately.\n\n"
        # Load traffic rules as general context
        ensure_init()
        general_rules = _format_rules_text(rule_loader.get_all_rules())
        return (
            intro + "TRAFFIC RULES:\n" + general_rules,
            "*General mode: traffic rules loaded. Ask any question!*",
        )

    ensure_init()
    cid = int(case_id)
    case = run_async(db.get_case(cid))
    if not case:
        return "", f"β Case {cid} not found."

    analyses = run_async(db.get_analyses_by_case(cid))
    parties = run_async(db.get_parties_by_case(cid))
    violations = run_async(db.get_violations_by_case(cid))
    fault = run_async(db.get_fault_analysis(cid))
    rules_data = rule_loader.get_all_rules()

    pieces = [f"""You are TraceScene AI assistant analyzing Case #{case.get('case_number', '')}.
Location: {case.get('location', 'Unknown')}
Date: {case.get('incident_date', 'Unknown')}
Officer: {case.get('officer_name', 'Unknown')}
Status: {case.get('status', 'Unknown')}
SCENE ANALYSES:\n"""]
    pieces.extend(
        f"\n--- Photo Analysis ---\n{entry.get('raw_analysis', '')}\n"
        for entry in analyses
    )
    if parties:
        pieces.append("\nPARTIES IDENTIFIED:\n")
        pieces.extend(
            f" - {party.get('label', '')}: {party.get('vehicle_type', '')} {party.get('vehicle_color', '')} β {party.get('vehicle_description', '')}\n"
            for party in parties
        )
    if violations:
        pieces.append("\nVIOLATIONS FOUND:\n")
        pieces.extend(
            f" - {item.get('rule_title', '')} (Severity: {item.get('severity', '')}, Confidence: {item.get('confidence', 0):.0%})\n"
            for item in violations
        )
    if fault:
        pieces.append(
            f"\nFAULT ANALYSIS:\n Primary Fault: {fault.get('primary_fault_party', 'N/A')}\n Confidence: {fault.get('overall_confidence', 0):.0%}\n Summary: {fault.get('analysis_summary', '')}\n"
        )
    # Append traffic rules
    pieces.append("\nTRAFFIC RULES:\n" + _format_rules_text(rules_data))

    status = f"β Case **{case.get('case_number', '')}** loaded with {len(analyses)} analyses, {len(violations)} violations."
    return "".join(pieces), status
def chat_respond(user_message, history, system_ctx):
    """Answer one chat turn and append it to the conversation history.

    Returns ``(history, "", system_ctx)``; the empty string clears the input
    box. Blank messages leave the history untouched. A failing chat call is
    reported as the assistant's reply instead of being raised.
    """
    question = user_message.strip() if user_message else ""
    if not question:
        return history, "", system_ctx

    ensure_init()
    if not chat_engine.is_loaded:
        chat_engine.load_model()

    try:
        answer = gpu_run_chat(system_ctx, question)
    except Exception as exc:
        answer = f"Error: {exc}"

    turns = history or []
    turns.append((question, answer))
    return turns, "", system_ctx
def _extract_field(text, field):
    """Return the text after ``"<field>:"`` on its line in *text*, or "Unknown"."""
    import re
    m = re.search(rf"{re.escape(field)}:\s*(.+)", text, re.IGNORECASE)
    return m.group(1).strip() if m else "Unknown"


def _extract_color(make_str):
    """Return the first known colour name found in *make_str*, else a default blue."""
    colors = ["Red", "Blue", "White", "Black", "Silver", "Grey", "Green", "Yellow", "Brown", "Orange"]
    lowered = make_str.lower()
    for c in colors:
        if c.lower() in lowered:
            return c.lower()
    return "#3b82f6"


def generate_animation_fn(case_id):
    """Build an animated SVG sketch of the collision for a case.

    Scene details are parsed from the first stored analysis of the case.
    All model-derived text is HTML-escaped before being placed in the
    markup, since it originates from model output rather than trusted code.

    Args:
        case_id: ID of the case to visualize.

    Returns:
        An HTML string: the SVG animation, or a short red error paragraph
        when no ID is given or the case has no analyses.
    """
    import random
    from html import escape

    if not case_id:
        return "<p style='color:red;'>Enter a Case ID.</p>"
    ensure_init()
    analyses = run_async(db.get_analyses_by_case(int(case_id)))
    if not analyses:
        return "<p style='color:red;'>No analyses found. Run analysis first.</p>"

    # Parse scene details from the first analysis; a stored None is treated
    # as empty text so the field lookups fall back to "Unknown".
    raw = analyses[0].get("raw_analysis") or ""
    road_type = _extract_field(raw, "Road Type")
    num_vehicles = _extract_field(raw, "Vehicles Involved")
    impact = _extract_field(raw, "Area of Impact")
    category = _extract_field(raw, "Accident Category")
    v1_make = _extract_field(raw, "Vehicle 1 Make/Model")
    v2_make = _extract_field(raw, "Vehicle 2 Make/Model")
    has_v2 = v2_make != "Unknown"

    # Determine colors from extracted make
    v1_color = _extract_color(v1_make)
    v2_color = _extract_color(v2_make) if has_v2 else "#ef4444"

    road_is_intersection = "intersection" in road_type.lower()

    try:
        num_v = int(num_vehicles)
    except ValueError:
        # Field missing or not a plain integer: assume a single vehicle.
        num_v = 1

    # Unique ID to force Gradio to re-render on each click (enables replay)
    uid = random.randint(10000, 99999)

    # Severity drives animation duration and HUD colour; anything that is
    # not "mild" or "medium" uses the fastest / red styling.
    severity_style = {"mild": ("3s", "#22c55e"), "medium": ("2s", "#f59e0b")}
    dur, sev_color = severity_style.get(category.lower(), ("1.5s", "#ef4444"))

    # Build SVG road
    if road_is_intersection:
        road_svg = '''
<rect x="0" y="160" width="700" height="100" fill="#555"/>
<rect x="300" y="0" width="100" height="420" fill="#555"/>
<line x1="0" y1="210" x2="300" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="400" y1="210" x2="700" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="350" y1="0" x2="350" y2="160" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="350" y1="260" x2="350" y2="420" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
'''
    else:
        road_svg = '''
<rect x="0" y="150" width="700" height="120" fill="#555" rx="2"/>
<line x1="0" y1="210" x2="700" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
<line x1="0" y1="150" x2="700" y2="150" stroke="white" stroke-width="2"/>
<line x1="0" y1="270" x2="700" y2="270" stroke="white" stroke-width="2"/>
'''

    # Vehicle 2 SVG (if present): drives down the cross street at an
    # intersection, otherwise approaches from the right.
    v2_svg = ""
    if has_v2:
        if road_is_intersection:
            v2_svg = f'''<g>
<animateTransform attributeName="transform" type="translate" from="0,0" to="0,135" dur="{dur}" fill="freeze"/>
<rect x="325" y="60" width="50" height="26" rx="5" fill="{v2_color}" stroke="#fff" stroke-width="1"/>
<text x="350" y="78" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V2</text>
</g>'''
        else:
            v2_svg = f'''<g>
<animateTransform attributeName="transform" type="translate" from="0,0" to="-200,0" dur="{dur}" fill="freeze"/>
<rect x="560" y="215" width="50" height="26" rx="5" fill="{v2_color}" stroke="#fff" stroke-width="1"/>
<text x="585" y="233" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V2</text>
</g>'''

    # Escape every piece of model-derived text that ends up in the markup.
    v1_label = escape(v1_make[:35])
    v2_label = escape(v2_make[:35]) if has_v2 else "Single vehicle accident"
    category_label = escape(category.upper())
    impact_label = escape(impact)
    road_label = escape(road_type)

    markup = f'''
<div style="text-align:center; font-family: Inter, Arial, sans-serif;">
<svg id="anim_{uid}" width="700" height="420" viewBox="0 0 700 420" xmlns="http://www.w3.org/2000/svg" style="border:1px solid #444; border-radius:10px; background:#1a1a2e;">
<defs>
<radialGradient id="glow_{uid}" cx="50%" cy="50%" r="50%">
<stop offset="0%" stop-color="#fbbf24" stop-opacity="0.8"/>
<stop offset="100%" stop-color="#fbbf24" stop-opacity="0"/>
</radialGradient>
</defs>
{road_svg}
<!-- Vehicle 1 -->
<g>
<animateTransform attributeName="transform" type="translate" from="0,0" to="200,0" dur="{dur}" fill="freeze"/>
<rect x="80" y="190" width="50" height="26" rx="5" fill="{v1_color}" stroke="#fff" stroke-width="1"/>
<text x="105" y="207" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V1</text>
</g>
{v2_svg}
<!-- Impact flash -->
<circle cx="340" cy="210" r="0" fill="url(#glow_{uid})">
<animate attributeName="r" values="0;0;0;0;0;0;0;45;55;0" dur="{dur}" fill="freeze"/>
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0.5;0" dur="{dur}" fill="freeze"/>
</circle>
<!-- Debris -->
<circle cx="340" cy="210" r="3" fill="#fbbf24" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;310;290" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;185;170" dur="{dur}" fill="freeze"/>
</circle>
<circle cx="340" cy="210" r="2" fill="#ef4444" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;370;395" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;190;175" dur="{dur}" fill="freeze"/>
</circle>
<circle cx="340" cy="210" r="3" fill="#e2e8f0" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;320;305" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;235;255" dur="{dur}" fill="freeze"/>
</circle>
<circle cx="340" cy="210" r="2" fill="#f97316" opacity="0">
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
<animate attributeName="cx" values="340;340;340;340;340;340;340;365;385" dur="{dur}" fill="freeze"/>
<animate attributeName="cy" values="210;210;210;210;210;210;210;230;250" dur="{dur}" fill="freeze"/>
</circle>
<!-- Collision label -->
<text x="350" y="145" fill="#ef4444" font-size="18" font-weight="bold" text-anchor="middle" opacity="0" font-family="Inter, Arial, sans-serif">
COLLISION
<animate attributeName="opacity" values="0;0;0;0;0;0;0;1;1" dur="{dur}" fill="freeze"/>
</text>
<!-- HUD -->
<rect x="10" y="350" width="680" height="60" rx="8" fill="rgba(0,0,0,0.6)"/>
<text x="20" y="375" fill="#e2e8f0" font-size="12" font-family="Inter, Arial, sans-serif">{v1_label}</text>
<text x="20" y="398" fill="#e2e8f0" font-size="12" font-family="Inter, Arial, sans-serif">{v2_label}</text>
<text x="680" y="375" fill="{sev_color}" font-size="14" font-weight="bold" text-anchor="end" font-family="Inter, Arial, sans-serif">{category_label}</text>
<text x="680" y="398" fill="#94a3b8" font-size="11" text-anchor="end" font-family="Inter, Arial, sans-serif">Impact: {impact_label} | Road: {road_label}</text>
</svg>
<div style="margin-top:8px; color:#94a3b8; font-size:12px;">
Vehicles: {num_v} | Animation auto-plays on load
</div>
</div>
'''
    return markup
# ── Build Gradio App ────────────────────────────────────────────────────
# Stylesheet for the Gradio UI: caps the container width and hides the footer.
CUSTOM_CSS = (
    "\n"
    ".gradio-container { max-width: 1200px !important; }\n"
    "footer { display: none !important; }\n"
)
| with gr.Blocks( | |
| title="TraceScene β AI Accident Analysis", | |
| ) as demo: | |
| gr.Markdown(""" | |
| # π TraceScene | |
| ### AI-Powered Accident Scene Analysis | |
| *GPU-accelerated inference via ZeroGPU (NVIDIA H200)* | |
| --- | |
| """) | |
| with gr.Tabs(): | |
| # Tab 1: Quick Analyze (single photo) | |
| with gr.TabItem("β‘ Quick Analyze"): | |
| gr.Markdown("Upload a photo for instant GPU-accelerated analysis.") | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_image = gr.Image(label="Upload Accident Photo", type="pil") | |
| quick_btn = gr.Button("π Analyze with GPU", variant="primary") | |
| with gr.Column(): | |
| quick_output = gr.Textbox(label="AI Analysis", lines=20) | |
| quick_btn.click(fn=gradio_analyze_photo, inputs=[input_image], outputs=[quick_output], api_name="analyze_photo") | |
| # Tab 2: Cases | |
| with gr.TabItem("π Cases"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Create Case") | |
| cn = gr.Textbox(label="Case Number *", placeholder="ACC-2026-001") | |
| on = gr.Textbox(label="Officer Name") | |
| loc = gr.Textbox(label="Location") | |
| dt = gr.Textbox(label="Incident Date", placeholder="YYYY-MM-DD") | |
| nt = gr.Textbox(label="Notes", lines=2) | |
| create_btn = gr.Button("Create Case", variant="primary") | |
| create_status = gr.Markdown() | |
| with gr.Column(scale=2): | |
| gr.Markdown("### Existing Cases") | |
| cases_tbl = gr.Dataframe( | |
| headers=["ID", "Case #", "Officer", "Location", "Date", "Status", "Photos"], | |
| interactive=False, | |
| ) | |
| with gr.Row(): | |
| refresh_btn = gr.Button("π Refresh") | |
| del_id = gr.Number(label="Case ID to Delete", precision=0) | |
| del_btn = gr.Button("ποΈ Delete", variant="stop") | |
| del_status = gr.Markdown() | |
| create_btn.click(create_case_fn, inputs=[cn, on, loc, dt, nt], outputs=[create_status, cases_tbl], api_name="create_case") | |
| refresh_btn.click(list_cases_fn, outputs=[cases_tbl], api_name="list_cases") | |
| del_btn.click(delete_case_fn, inputs=[del_id], outputs=[del_status, cases_tbl], api_name="delete_case") | |
| # Tab 3: Upload Photos | |
| with gr.TabItem("πΈ Photos"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| up_case = gr.Number(label="Case ID", precision=0) | |
| up_files = gr.File(label="Select Photos", file_count="multiple", file_types=["image"]) | |
| up_btn = gr.Button("Upload", variant="primary") | |
| up_status = gr.Markdown() | |
| with gr.Column(scale=2): | |
| pv_case = gr.Number(label="Preview Case ID", precision=0) | |
| pv_btn = gr.Button("Load Photos") | |
| gallery = gr.Gallery(label="Photos", columns=3) | |
| up_btn.click(upload_photos_fn, inputs=[up_case, up_files], outputs=[up_status], api_name="upload_photos") | |
| pv_btn.click(get_case_photos_fn, inputs=[pv_case], outputs=[gallery], api_name="get_case_photos") | |
| # Tab 4: Run Analysis | |
| with gr.TabItem("π§ Analysis"): | |
| gr.Markdown(""" | |
| ### Full Analysis Pipeline (GPU-accelerated) | |
| 1. Scene Analysis β 2. Rule Matching β 3. Fault Deduction | |
| """) | |
| an_case = gr.Number(label="Case ID", precision=0) | |
| an_btn = gr.Button("π Run Full Analysis", variant="primary", size="lg") | |
| an_status = gr.Markdown() | |
| with gr.Accordion("Scene Details", open=False): | |
| an_detail = gr.Markdown() | |
| an_violations = gr.Markdown(label="Violations & Fault") | |
| an_btn.click(run_analysis_fn, inputs=[an_case], outputs=[an_status, an_detail, an_violations], api_name="run_analysis") | |
| # Tab 5: Report | |
| with gr.TabItem("π Report"): | |
| rp_case = gr.Number(label="Case ID", precision=0) | |
| rp_btn = gr.Button("Generate Report", variant="primary") | |
| rp_out = gr.Markdown() | |
| rp_btn.click(generate_report_fn, inputs=[rp_case], outputs=[rp_out], api_name="generate_report") | |
| # Tab 6: Rules | |
| with gr.TabItem("π Rules"): | |
| ru_btn = gr.Button("Load Traffic Rules") | |
| ru_out = gr.Markdown() | |
| ru_btn.click(get_rules_fn, outputs=[ru_out], api_name="get_rules") | |
| # Tab 7: Chat Q&A | |
| with gr.TabItem("π¬ Chat"): | |
| gr.Markdown("### Case Q&A Chatbot\nAsk questions about logged cases, traffic rules, or insurance clauses.") | |
| with gr.Row(): | |
| chat_case_id = gr.Number(label="Case ID (optional)", precision=0) | |
| chat_load_btn = gr.Button("π Load Case Context", variant="secondary") | |
| chat_context_status = gr.Markdown(value="*No case loaded. You can still ask general traffic/insurance questions.*") | |
| chatbot = gr.Chatbot(label="Conversation", height=400) | |
| chat_input = gr.Textbox(label="Your Question", placeholder="e.g. What vehicles were involved? What rules were violated?", lines=2) | |
| with gr.Row(): | |
| chat_send_btn = gr.Button("π¬ Send", variant="primary") | |
| chat_clear_btn = gr.Button("ποΈ Clear") | |
| # State for context | |
| chat_system_ctx = gr.State(value="You are TraceScene AI assistant. You help insurers and investigating officers analyze accident cases, traffic rules, and insurance clauses. Answer concisely and accurately based on the context provided.") | |
| chat_load_btn.click(load_chat_context, inputs=[chat_case_id], outputs=[chat_system_ctx, chat_context_status]) | |
| chat_send_btn.click(chat_respond, inputs=[chat_input, chatbot, chat_system_ctx], outputs=[chatbot, chat_input, chat_system_ctx], api_name="chat") | |
| chat_input.submit(chat_respond, inputs=[chat_input, chatbot, chat_system_ctx], outputs=[chatbot, chat_input, chat_system_ctx]) | |
| chat_clear_btn.click(lambda: ([], ""), outputs=[chatbot, chat_input]) | |
| # Tab 8: 2D Animation | |
| with gr.Tab("Simulation"): | |
| gr.Markdown("### 2D Accident Simulation\nVisualize the top-down perspective of the incident.") | |
| anim_case_id = gr.Number(label="Case ID", precision=0) | |
| anim_btn = gr.Button("Generate Animation", variant="primary") | |
| anim_output = gr.HTML(label="Animation View") | |
| anim_btn.click(generate_animation_fn, inputs=[anim_case_id], outputs=[anim_output]) | |
| # Hidden API-only endpoints (for @gradio/client from custom frontend) | |
| with gr.TabItem("π API", visible=False): | |
| api_health_btn = gr.Button("health") | |
| api_health_out = gr.Textbox() | |
| api_health_btn.click(health_fn, outputs=[api_health_out], api_name="health") | |
| api_cases_btn = gr.Button("list_cases_json") | |
| api_cases_out = gr.Textbox() | |
| api_cases_btn.click(list_cases_json, outputs=[api_cases_out], api_name="list_cases_json") | |
| api_case_id = gr.Number(precision=0) | |
| api_case_btn = gr.Button("get_case") | |
| api_case_out = gr.Textbox() | |
| api_case_btn.click(get_case_json, inputs=[api_case_id], outputs=[api_case_out], api_name="get_case") | |
| api_report_id = gr.Number(precision=0) | |
| api_report_btn = gr.Button("get_report") | |
| api_report_out = gr.Textbox() | |
| api_report_btn.click(get_report_json, inputs=[api_report_id], outputs=[api_report_out], api_name="get_report_json") | |
| api_rules_btn = gr.Button("get_rules_json") | |
| api_rules_out = gr.Textbox() | |
| api_rules_btn.click(get_rules_json, outputs=[api_rules_out], api_name="get_rules_json") | |
| gr.Markdown("---\n*TraceScene β Built by Siddharth Ravikumar | tracescene@zohomail.ae*") | |
# ── Create FastAPI App & Mount Gradio ────────────────────────────────
def _mount_static_dir(target, url_path, directory, name):
    """Mount ``directory`` on ``target`` at ``url_path`` if it is a real directory.

    ``StaticFiles`` checks its directory when it is constructed and raises if
    the directory is missing, which would abort the whole module import. A
    missing (or non-directory) path is therefore logged and skipped instead.

    Args:
        target: Application object exposing ``mount(path, app, name=...)``.
        url_path: URL prefix to serve the files under (e.g. ``"/css"``).
        directory: ``pathlib.Path`` of the folder to serve.
        name: Route name passed through to ``mount``.

    Returns:
        True when the directory was mounted, False when it was skipped.
    """
    if not directory.is_dir():
        logger.warning(f"Static directory not found, skipping mount {url_path}: {directory}")
        return False
    target.mount(url_path, StaticFiles(directory=str(directory)), name=name)
    return True


app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive CORS setup possible; FastAPI's CORS documentation advises
# listing explicit origins when credentials are allowed. Confirm this is
# intended before the API starts relying on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static files (frontend)
frontend_dir = Path(__file__).resolve().parent / "frontend"
if frontend_dir.exists():
    # Mount each asset subfolder at the root so the frontend can use
    # root-relative URLs (/css, /js, /images, /static). Every subfolder is
    # checked on its own: a partially populated frontend must not crash startup.
    for _asset in ("css", "js", "images", "static"):
        _mount_static_dir(app, f"/{_asset}", frontend_dir / _asset, _asset)

# Serve uploads folder (skipped when the upload path is not a directory).
_mount_static_dir(app, "/uploads", settings.upload_path, "uploads")
async def serve_frontend():
    """Return the custom frontend's ``index.html``, or a JSON stub when it is absent.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm where (and at which path) it is registered.
    """
    landing_page = frontend_dir / "index.html"
    if not landing_page.exists():
        # No frontend bundle on disk: answer with a small JSON payload that
        # names the API and its docs path instead.
        return {"message": "TraceScene API", "docs": "/docs"}
    return FileResponse(str(landing_page))
# REST API routes from backend.app.api.routes.
app.include_router(router)

# Mount the Gradio Blocks UI (`demo`) at the root path "/".
# NOTE(review): this presumably has to stay after the router and static
# mounts, since a mount at "/" is expected to catch every path that no
# earlier route claimed — confirm before reordering.
app = gr.mount_gradio_app(app, demo, path="/")
def _report_zero_gpu_startup():
    """Best-effort manual trigger of the ZeroGPU startup report.

    Author's rationale (Hugging Face ZeroGPU fix): when the Gradio app is
    served through ``gr.mount_gradio_app`` on a custom FastAPI app,
    ``gr.Blocks.launch()`` is bypassed. The ``spaces`` library hooks
    ``.launch()`` to emit the ``startup_report`` that the ZeroGPU orchestrator
    needs to verify that ``@spaces.GPU`` functions exist; without it the Hub
    errors out with "No @spaces.GPU function detected". So it is triggered
    manually here.

    Never raises: a missing ``spaces`` package is ignored silently and any
    other failure is logged as a warning.
    """
    try:
        from spaces import config

        # Nothing to report unless ZeroGPU is enabled in the spaces config.
        if not getattr(config.Config, "zero_gpu", False):
            return

        import spaces.zero as zero

        # Two entry points are probed, in this order of preference.
        if hasattr(zero, "startup"):
            zero.startup()
            logger.info("Triggered ZeroGPU startup successfully.")
        elif hasattr(zero, "client"):
            zero.torch.pack()
            zero.client.startup_report()
            logger.info("Triggered ZeroGPU client startup manually.")
    except ImportError:
        # `spaces` internals unavailable: nothing to do.
        pass
    except Exception as e:
        logger.warning(f"Failed to manually trigger ZeroGPU startup report: {e}")


# Startup event wrapper
async def startup_event():
    """Run application initialisation, then emit the ZeroGPU startup report.

    NOTE(review): no startup-hook registration for this coroutine is visible
    in this view; confirm it is wired to the FastAPI app's startup.
    """
    logger.info("Starting up FastAPI application...")
    await _ensure_init()
    _report_zero_gpu_startup()
if __name__ == "__main__":
    # Direct execution: serve the combined FastAPI + Gradio app with uvicorn,
    # listening on all interfaces.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)