"""FastAPI service for code-similarity analysis and AST inspection.

Analysis jobs run on daemon worker threads; clients poll ``/api/progress``
and fetch the payload from ``/api/result`` once the job reports ``done``.
"""

import ast
import logging
import threading
from collections import Counter
from pathlib import Path
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from app.detector import (
    CodeSimilarityAnalyzer,
    MODELS_ROOT,
    generate_agent_report,
    get_default_hub_dataset,
)

logger = logging.getLogger(__name__)

BASE_DIR = Path(__file__).resolve().parent.parent
STATIC_DIR = BASE_DIR / "static"

app = FastAPI(title="Modular Model Graph")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


class AnalyzeRequest(BaseModel):
    """Payload for POST /api/analyze."""

    code: str = Field(..., min_length=1)
    top_k: int = Field(default=5, ge=1, le=25)
    granularity: str = "method"
    precision: str = "float32"
    hub_dataset: str | None = None
    exclude_same_model: bool = True
    exclude_identical: bool = True
    exclude_models: list[str] = Field(default_factory=list)
    source_path: str | None = None


class AstRequest(BaseModel):
    """Payload for POST /api/ast."""

    code: str = Field(..., min_length=1)
    symbol: str = Field(..., min_length=1)
    # "relative/path.py:Symbol" naming a corpus definition to compare against.
    match_identifier: str | None = None


# Analyzer cache keyed by (precision, granularity, hub_dataset).  Guarded by a
# lock because _run_analysis executes on worker threads: without it, two
# concurrent jobs with the same key race on check-then-insert and can build
# duplicate (expensive) analyzers.
_ANALYZERS: dict[tuple[str, str, str], CodeSimilarityAnalyzer] = {}
_ANALYZERS_LOCK = threading.Lock()

# Job table shared between request handlers and worker threads.
_JOBS: dict[str, dict[str, object]] = {}
_JOBS_LOCK = threading.Lock()


def _get_analyzer(precision: str, granularity: str, hub_dataset: str) -> CodeSimilarityAnalyzer:
    """Return the cached analyzer for this configuration, creating it on first use.

    Thread-safe: the lock is held across construction so concurrent jobs
    requesting the same configuration always share a single instance.
    """
    key = (precision, granularity, hub_dataset)
    with _ANALYZERS_LOCK:
        analyzer = _ANALYZERS.get(key)
        if analyzer is None:
            analyzer = CodeSimilarityAnalyzer(
                hub_dataset=hub_dataset,
                precision=precision,
                granularity=granularity,
            )
            _ANALYZERS[key] = analyzer
        return analyzer


def _update_job(job_id: str, payload: dict[str, object]) -> None:
    """Merge *payload* into the job record; silently ignore unknown job ids."""
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            return
        job.update(payload)


def _run_analysis(job_id: str, payload: dict[str, object]) -> None:
    """Worker-thread entry point: run one analysis job and record its outcome.

    All failures are captured into the job record so the polling client sees
    ``status == "error"`` instead of a dead job.
    """
    try:
        _update_job(job_id, {"status": "running", "progress": 0, "total": 1, "message": "starting"})
        hub_dataset = payload.get("hub_dataset") or get_default_hub_dataset()
        analyzer = _get_analyzer(payload["precision"], payload["granularity"], hub_dataset)

        def progress_cb(done: int, total: int, message: str) -> None:
            _update_job(
                job_id,
                {"status": "running", "progress": done, "total": total, "message": message},
            )

        results = analyzer.analyze_code(
            payload["code"],
            top_k_per_item=payload["top_k"],
            exclude_same_model=payload["exclude_same_model"],
            exclude_identical=payload["exclude_identical"],
            exclude_models=payload.get("exclude_models") or [],
            source_path=payload["source_path"],
            progress=progress_cb,
        )
        response = {
            "results": results["results"],
            "overall": results["overall"],
            "by_class": results["by_class"],
            "overall_all": results.get("overall_all", []),
            "by_class_all": results.get("by_class_all", {}),
            "identical_filtered": results.get("identical_filtered", 0),
            "agent_report": generate_agent_report(results),
            "index_info": analyzer.index_status(),
        }
        _update_job(job_id, {"status": "done", "message": "done", "result": response})
    except Exception as exc:  # top-level thread boundary: record, never crash the thread
        error_text = f"{type(exc).__name__}: {exc}"
        logger.exception("[analysis error] job=%s %s", job_id, error_text)
        _update_job(job_id, {"status": "error", "error": error_text, "message": "failed"})


def _find_definition(tree: ast.AST, symbol: str) -> ast.AST | None:
    """Locate a top-level function/class, or a ``Class.method``, in *tree*.

    Only module-level definitions (and their direct methods) are searched;
    nested scopes are intentionally ignored.
    """
    if "." in symbol:
        class_name, method_name = symbol.split(".", 1)
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                for child in node.body:
                    if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == method_name:
                        return child
        return None
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == symbol:
            return node
    return None


def _call_name(node: ast.AST) -> str | None:
    """Best-effort name of a call target: bare name or trailing attribute."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None


def _summarize_ast(node: ast.AST) -> dict[str, list[dict[str, int]]]:
    """Count AST node types and called names under *node* (top 8 of each)."""
    counts: Counter[str] = Counter()
    calls: Counter[str] = Counter()
    for child in ast.walk(node):
        counts[type(child).__name__] += 1
        if isinstance(child, ast.Call):
            name = _call_name(child.func)
            if name:
                calls[name] += 1
    node_counts = [{"name": name, "count": count} for name, count in counts.most_common(8)]
    call_counts = [{"name": name, "count": count} for name, count in calls.most_common(8)]
    return {"node_counts": node_counts, "calls": call_counts}


def _strip_node_docstring(node: ast.AST) -> None:
    """Remove a leading docstring from a function/class node, in place."""
    if not isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef)):
        return
    if not node.body:
        return
    first = node.body[0]
    if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant):
        if isinstance(first.value.value, str):
            node.body.pop(0)


def _get_structural_flow(node: ast.AST) -> str:
    """Render a coarse call/control flow string (e.g. ``self.setup -> [LOGIC] -> Return``).

    Consecutive duplicates are collapsed and output is capped at 20 steps.
    """
    flow: list[str] = []
    for child in ast.walk(node):
        if isinstance(child, ast.Call):
            name = None
            # Prefer the explicit "self.method" form for instance calls.
            if isinstance(child.func, ast.Attribute) and isinstance(child.func.value, ast.Name):
                if child.func.value.id == "self":
                    name = f"self.{child.func.attr}"
            if name is None:
                name = _call_name(child.func)
            if name:
                flow.append(name)
        elif isinstance(child, ast.Attribute):
            if isinstance(child.value, ast.Name) and child.value.id == "self":
                flow.append(f"self.{child.attr}")
        elif isinstance(child, (ast.If, ast.While, ast.For)):
            flow.append("[LOGIC]")
        elif isinstance(child, ast.Return):
            flow.append("Return")
    reduced: list[str] = []
    for item in flow:
        if not reduced or reduced[-1] != item:
            reduced.append(item)
    return " -> ".join(reduced[:20])


def _extract_ast(source: str, symbol: str) -> tuple[str | None, dict[str, object] | None]:
    """Extract *symbol* from *source*: (unparsed code sans docstring, summary + flow).

    Returns ``(None, None)`` when the symbol is absent.  Raises ``SyntaxError``
    when *source* is not valid Python — callers decide how to surface that.
    """
    tree = ast.parse(source)
    node = _find_definition(tree, symbol)
    if node is None:
        return None, None
    _strip_node_docstring(node)
    try:
        code = ast.unparse(node)
    except Exception:
        # Unparsing is cosmetic; still return the structural summary.
        code = None
    return code, {"summary": _summarize_ast(node), "flow": _get_structural_flow(node)}


@app.get("/")
async def index() -> FileResponse:
    """Serve the single-page UI."""
    return FileResponse(STATIC_DIR / "index.html")


@app.post("/api/analyze")
async def analyze(request: AnalyzeRequest) -> dict:
    """Queue an analysis job on a daemon thread and return its id."""
    job_id = uuid4().hex
    payload = request.model_dump()
    with _JOBS_LOCK:
        _JOBS[job_id] = {
            "status": "queued",
            "progress": 0,
            "total": 1,
            "message": "queued",
            "result": None,
        }
    thread = threading.Thread(target=_run_analysis, args=(job_id, payload), daemon=True)
    thread.start()
    return {"job_id": job_id}


@app.get("/api/progress/{job_id}")
async def progress(job_id: str) -> dict:
    """Return a snapshot of a job's progress; 404 for unknown ids."""
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")
        return {
            "status": job.get("status"),
            "progress": job.get("progress", 0),
            "total": job.get("total", 1),
            "message": job.get("message", ""),
            "error": job.get("error"),
        }


@app.get("/api/models")
async def models() -> dict:
    """List the models known to the similarity corpus."""
    return {"models": CodeSimilarityAnalyzer.list_models()}


@app.get("/api/result/{job_id}")
async def result(job_id: str) -> dict:
    """Return a finished job's result; 404 unknown, 500 failed, 409 unfinished."""
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")
        if job.get("status") == "error":
            raise HTTPException(status_code=500, detail=job.get("error", "Unknown error"))
        if job.get("status") != "done":
            raise HTTPException(status_code=409, detail="Job not finished")
        return job.get("result") or {}


@app.post("/api/ast")
async def ast_view(request: AstRequest) -> dict:
    """Extract and summarize the query symbol's AST, optionally beside a corpus match."""
    try:
        query_ast, query_summary = _extract_ast(request.code, request.symbol)
    except SyntaxError as exc:
        # Invalid user-submitted code is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid Python source: {exc}") from exc

    match_ast = None
    match_summary = None
    if request.match_identifier and ":" in request.match_identifier:
        relative_path, match_name = request.match_identifier.split(":", 1)
        models_root = MODELS_ROOT.resolve()
        file_path = (models_root / relative_path).resolve()
        # Security: relative_path is client-controlled; reject "../" traversal
        # out of MODELS_ROOT before touching the filesystem.
        if file_path.is_relative_to(models_root) and file_path.exists():
            match_source = file_path.read_text(encoding="utf-8")
            try:
                match_ast, match_summary = _extract_ast(match_source, match_name)
            except SyntaxError:
                # Corpus file is not parseable; return the query side only.
                logger.warning("Unparsable corpus file: %s", file_path)
    return {
        "query_ast": query_ast,
        "query_summary": query_summary,
        "match_ast": match_ast,
        "match_summary": match_summary,
        "match_identifier": request.match_identifier,
    }