Spaces:
Sleeping
Sleeping
import ast
import logging
import threading
from collections import Counter
from pathlib import Path
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from app.detector import (
    CodeSimilarityAnalyzer,
    MODELS_ROOT,
    generate_agent_report,
    get_default_hub_dataset,
)
# Project root (two levels above this module) and the bundled frontend assets.
BASE_DIR = Path(__file__).resolve().parent.parent
STATIC_DIR = BASE_DIR / "static"

app = FastAPI(title="Modular Model Graph")
# Serve the static frontend under /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
class AnalyzeRequest(BaseModel):
    # Request body for starting an analysis job; fields are forwarded to
    # CodeSimilarityAnalyzer / analyze_code by _run_analysis.
    code: str = Field(..., min_length=1)  # source code to analyze
    top_k: int = Field(default=5, ge=1, le=25)  # passed as top_k_per_item
    granularity: str = "method"  # part of the analyzer cache key
    precision: str = "float32"  # part of the analyzer cache key
    hub_dataset: str | None = None  # None -> get_default_hub_dataset()
    exclude_same_model: bool = True  # forwarded to analyze_code
    exclude_identical: bool = True  # forwarded to analyze_code
    exclude_models: list[str] = Field(default_factory=list)  # forwarded to analyze_code
    source_path: str | None = None  # forwarded to analyze_code
class AstRequest(BaseModel):
    # Request body for the AST comparison view (ast_view).
    code: str = Field(..., min_length=1)  # query source code
    symbol: str = Field(..., min_length=1)  # "name" or "Class.method" to extract
    match_identifier: str | None = None  # "relative/path.py:symbol" resolved under MODELS_ROOT
# Analyzer cache keyed by (precision, granularity, hub_dataset). Guarded by a
# lock because _get_analyzer is called from _run_analysis worker threads.
_ANALYZERS: dict[tuple[str, str, str], CodeSimilarityAnalyzer] = {}
_ANALYZERS_LOCK = threading.Lock()

# Job registry: job_id -> status/progress/result record, guarded by _JOBS_LOCK.
_JOBS: dict[str, dict[str, object]] = {}
_JOBS_LOCK = threading.Lock()


def _get_analyzer(precision: str, granularity: str, hub_dataset: str) -> CodeSimilarityAnalyzer:
    """Return the cached analyzer for this configuration, building it on first use.

    Creation is serialized with ``_ANALYZERS_LOCK``: this function is invoked
    from multiple daemon worker threads, and an unsynchronized check-then-set
    would let concurrent cache misses construct duplicate (presumably
    expensive) analyzers and race on the dict. Holding the lock across
    construction also means concurrent requests for different configurations
    queue behind one another — an accepted trade-off for simplicity.
    """
    key = (precision, granularity, hub_dataset)
    with _ANALYZERS_LOCK:
        analyzer = _ANALYZERS.get(key)
        if analyzer is None:
            analyzer = CodeSimilarityAnalyzer(
                hub_dataset=hub_dataset,
                precision=precision,
                granularity=granularity,
            )
            _ANALYZERS[key] = analyzer
    return analyzer
def _update_job(job_id: str, payload: dict[str, object]) -> None:
    """Merge *payload* into the record for *job_id*; silently ignore unknown jobs."""
    with _JOBS_LOCK:
        try:
            _JOBS[job_id].update(payload)
        except KeyError:
            # Job was never registered (or is unknown) — nothing to update.
            pass
def _run_analysis(job_id: str, payload: dict[str, object]) -> None:
    """Worker-thread body: run one analysis job and record its outcome.

    Progress, success ("result") and failure ("error") are all written into
    the shared job record via _update_job. This function never raises: it runs
    in a daemon thread where an escaped exception would be silently lost.
    """
    try:
        _update_job(job_id, {"status": "running", "progress": 0, "total": 1, "message": "starting"})
        hub_dataset = payload.get("hub_dataset") or get_default_hub_dataset()
        analyzer = _get_analyzer(payload["precision"], payload["granularity"], hub_dataset)

        def progress_cb(done: int, total: int, message: str) -> None:
            # Forward analyzer progress into the job record for the progress endpoint.
            _update_job(
                job_id,
                {"status": "running", "progress": done, "total": total, "message": message},
            )

        results = analyzer.analyze_code(
            payload["code"],
            top_k_per_item=payload["top_k"],
            exclude_same_model=payload["exclude_same_model"],
            exclude_identical=payload["exclude_identical"],
            exclude_models=payload.get("exclude_models") or [],
            # .get for consistency with exclude_models: a trimmed payload no
            # longer raises KeyError here.
            source_path=payload.get("source_path"),
            progress=progress_cb,
        )
        response = {
            "results": results["results"],
            "overall": results["overall"],
            "by_class": results["by_class"],
            "overall_all": results.get("overall_all", []),
            "by_class_all": results.get("by_class_all", {}),
            "identical_filtered": results.get("identical_filtered", 0),
            "agent_report": generate_agent_report(results),
            "index_info": analyzer.index_status(),
        }
        _update_job(job_id, {"status": "done", "message": "done", "result": response})
    except Exception as exc:
        error_text = f"{type(exc).__name__}: {exc}"
        # logger.exception preserves the traceback and honours logging config;
        # the previous print() to stdout lost both.
        logging.getLogger(__name__).exception("analysis job %s failed: %s", job_id, error_text)
        _update_job(job_id, {"status": "error", "error": error_text, "message": "failed"})
| def _find_definition(tree: ast.AST, symbol: str) -> ast.AST | None: | |
| if "." in symbol: | |
| class_name, method_name = symbol.split(".", 1) | |
| for node in tree.body: | |
| if isinstance(node, ast.ClassDef) and node.name == class_name: | |
| for child in node.body: | |
| if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == method_name: | |
| return child | |
| return None | |
| for node in tree.body: | |
| if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == symbol: | |
| return node | |
| return None | |
| def _call_name(node: ast.AST) -> str | None: | |
| if isinstance(node, ast.Name): | |
| return node.id | |
| if isinstance(node, ast.Attribute): | |
| return node.attr | |
| return None | |
| def _summarize_ast(node: ast.AST) -> dict[str, list[dict[str, int]]]: | |
| counts: Counter[str] = Counter() | |
| calls: Counter[str] = Counter() | |
| for child in ast.walk(node): | |
| counts[type(child).__name__] += 1 | |
| if isinstance(child, ast.Call): | |
| name = _call_name(child.func) | |
| if name: | |
| calls[name] += 1 | |
| node_counts = [{"name": name, "count": count} for name, count in counts.most_common(8)] | |
| call_counts = [{"name": name, "count": count} for name, count in calls.most_common(8)] | |
| return {"node_counts": node_counts, "calls": call_counts} | |
| def _strip_node_docstring(node: ast.AST) -> None: | |
| if not isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef)): | |
| return | |
| if not node.body: | |
| return | |
| first = node.body[0] | |
| if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant): | |
| if isinstance(first.value.value, str): | |
| node.body.pop(0) | |
| def _get_structural_flow(node: ast.AST) -> str: | |
| flow: list[str] = [] | |
| for child in ast.walk(node): | |
| if isinstance(child, ast.Call): | |
| name = None | |
| if isinstance(child.func, ast.Attribute) and isinstance(child.func.value, ast.Name): | |
| if child.func.value.id == "self": | |
| name = f"self.{child.func.attr}" | |
| if name is None: | |
| name = _call_name(child.func) | |
| if name: | |
| flow.append(name) | |
| elif isinstance(child, ast.Attribute): | |
| if isinstance(child.value, ast.Name) and child.value.id == "self": | |
| flow.append(f"self.{child.attr}") | |
| elif isinstance(child, (ast.If, ast.While, ast.For)): | |
| flow.append("[LOGIC]") | |
| elif isinstance(child, ast.Return): | |
| flow.append("Return") | |
| reduced: list[str] = [] | |
| for item in flow: | |
| if not reduced or reduced[-1] != item: | |
| reduced.append(item) | |
| return " -> ".join(reduced[:20]) | |
def _extract_ast(source: str, symbol: str) -> tuple[str | None, dict[str, object] | None]:
    """Extract *symbol*'s definition from *source* and summarize it.

    Returns ``(code, details)`` where ``code`` is the unparsed definition with
    its docstring stripped (or None if unparsing fails) and ``details`` holds
    the node summary and structural flow. Returns ``(None, None)`` when the
    symbol is absent. Propagates SyntaxError from ``ast.parse``.
    """
    tree = ast.parse(source)
    target = _find_definition(tree, symbol)
    if target is None:
        return None, None
    _strip_node_docstring(target)
    try:
        rendered = ast.unparse(target)
    except Exception:
        rendered = None
    details = {"summary": _summarize_ast(target), "flow": _get_structural_flow(target)}
    return rendered, details
async def index() -> FileResponse:
    # Serve the frontend entry point.
    # NOTE(review): no @app.get("/") decorator is visible on any handler in
    # this view of the file — confirm route registration exists elsewhere or
    # was lost; as shown, this handler is never mounted.
    return FileResponse(STATIC_DIR / "index.html")
async def analyze(request: AnalyzeRequest) -> dict:
    """Register a new analysis job, start its worker thread, return the job id.

    NOTE(review): finished jobs are never removed from _JOBS — confirm an
    eviction policy exists elsewhere, otherwise the registry grows unbounded.
    """
    job_id = uuid4().hex
    payload = request.model_dump()
    initial_record: dict[str, object] = {
        "status": "queued",
        "progress": 0,
        "total": 1,
        "message": "queued",
        "result": None,
    }
    with _JOBS_LOCK:
        _JOBS[job_id] = initial_record
    worker = threading.Thread(target=_run_analysis, args=(job_id, payload), daemon=True)
    worker.start()
    return {"job_id": job_id}
async def progress(job_id: str) -> dict:
    """Report status/progress/message for a job; 404 when the id is unknown."""
    with _JOBS_LOCK:
        record = _JOBS.get(job_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Job not found")
        snapshot = {
            "status": record.get("status"),
            "progress": record.get("progress", 0),
            "total": record.get("total", 1),
            "message": record.get("message", ""),
            "error": record.get("error"),
        }
    return snapshot
async def models() -> dict:
    # Expose the detector's known-model list (e.g. for an exclude-models picker).
    return {"models": CodeSimilarityAnalyzer.list_models()}
async def result(job_id: str) -> dict:
    """Return a finished job's result payload.

    Raises 404 for an unknown id, 500 when the job failed, and 409 while it
    is still queued or running.
    """
    with _JOBS_LOCK:
        record = _JOBS.get(job_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Job not found")
        status = record.get("status")
        if status == "error":
            raise HTTPException(status_code=500, detail=record.get("error", "Unknown error"))
        if status != "done":
            raise HTTPException(status_code=409, detail="Job not finished")
        return record.get("result") or {}
async def ast_view(request: AstRequest) -> dict:
    """Return AST renderings and structural summaries for the query symbol
    and, optionally, for a matched symbol read from a model source file.

    Raises:
        HTTPException 400: when the query code is not valid Python, or when
            the client-supplied match path resolves outside MODELS_ROOT.
    """
    try:
        query_ast, query_summary = _extract_ast(request.code, request.symbol)
    except SyntaxError as exc:
        # Client-supplied code: surface a 400 rather than an unhandled 500.
        raise HTTPException(status_code=400, detail=f"Query code is not valid Python: {exc}") from exc
    match_ast = None
    match_summary = None
    if request.match_identifier and ":" in request.match_identifier:
        relative_path, match_name = request.match_identifier.split(":", 1)
        file_path = (MODELS_ROOT / relative_path).resolve()
        # Path-traversal guard: relative_path comes from the client, so a
        # "../" sequence must not be allowed to escape MODELS_ROOT.
        if not file_path.is_relative_to(Path(MODELS_ROOT).resolve()):
            raise HTTPException(status_code=400, detail="Invalid match identifier path")
        if file_path.exists():
            match_source = file_path.read_text(encoding="utf-8")
            try:
                match_ast, match_summary = _extract_ast(match_source, match_name)
            except SyntaxError:
                # An unparseable model file simply yields no match view.
                match_ast = None
                match_summary = None
    return {
        "query_ast": query_ast,
        "query_summary": query_summary,
        "match_ast": match_ast,
        "match_summary": match_summary,
        "match_identifier": request.match_identifier,
    }