# Scraped page header (Hugging Face Space): author Molbap (HF Staff),
# commit 4fe7080 (verified) — "Update app with better diff, new style".
import ast
import threading
from uuid import uuid4
from collections import Counter
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from app.detector import (
CodeSimilarityAnalyzer,
MODELS_ROOT,
generate_agent_report,
get_default_hub_dataset,
)
# This module lives one directory below the project root (app/), so the
# repo root is two levels up from this file.
BASE_DIR = Path(__file__).resolve().parent.parent
STATIC_DIR = BASE_DIR / "static"
app = FastAPI(title="Modular Model Graph")
# Serve the front-end assets (index.html, JS, CSS) under /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
class AnalyzeRequest(BaseModel):
    """Request body for POST /api/analyze."""

    code: str = Field(..., min_length=1)  # source code to analyze
    top_k: int = Field(default=5, ge=1, le=25)  # matches kept per analyzed item
    granularity: str = "method"  # forwarded to the analyzer cache key
    precision: str = "float32"  # forwarded to the analyzer cache key
    hub_dataset: str | None = None  # None/empty -> get_default_hub_dataset()
    exclude_same_model: bool = True
    exclude_identical: bool = True
    exclude_models: list[str] = Field(default_factory=list)
    source_path: str | None = None  # optional path hint forwarded to analyze_code
class AstRequest(BaseModel):
    """Request body for POST /api/ast."""

    code: str = Field(..., min_length=1)  # source containing the query symbol
    symbol: str = Field(..., min_length=1)  # "func", "Class", or "Class.method"
    match_identifier: str | None = None  # "relative/path.py:symbol" of a match
# Cache of analyzer instances keyed by (precision, granularity, hub_dataset).
_ANALYZERS: dict[tuple[str, str, str], CodeSimilarityAnalyzer] = {}
# In-memory job store for background analyses; all access guarded by _JOBS_LOCK.
_JOBS: dict[str, dict[str, object]] = {}
_JOBS_LOCK = threading.Lock()
def _get_analyzer(precision: str, granularity: str, hub_dataset: str) -> CodeSimilarityAnalyzer:
    """Return the cached analyzer for this configuration, creating it on first use."""
    cache_key = (precision, granularity, hub_dataset)
    try:
        return _ANALYZERS[cache_key]
    except KeyError:
        pass
    created = CodeSimilarityAnalyzer(
        hub_dataset=hub_dataset,
        precision=precision,
        granularity=granularity,
    )
    _ANALYZERS[cache_key] = created
    return created
def _update_job(job_id: str, payload: dict[str, object]) -> None:
    """Merge *payload* into the job record for *job_id*; unknown ids are ignored."""
    with _JOBS_LOCK:
        record = _JOBS.get(job_id)
        if record is not None:
            record.update(payload)
def _run_analysis(job_id: str, payload: dict[str, object]) -> None:
    """Background worker: run one similarity analysis and record the outcome.

    Runs in a daemon thread started by POST /api/analyze; all state is
    reported back through the _JOBS record for *job_id* via _update_job.
    Never raises: any failure is captured into the job record instead.
    """
    try:
        _update_job(job_id, {"status": "running", "progress": 0, "total": 1, "message": "starting"})
        # A missing/empty hub_dataset falls back to the configured default.
        hub_dataset = payload.get("hub_dataset") or get_default_hub_dataset()
        analyzer = _get_analyzer(payload["precision"], payload["granularity"], hub_dataset)
        def progress_cb(done: int, total: int, message: str) -> None:
            # Forwarded into analyze_code so clients can poll /api/progress.
            _update_job(
                job_id,
                {"status": "running", "progress": done, "total": total, "message": message},
            )
        results = analyzer.analyze_code(
            payload["code"],
            top_k_per_item=payload["top_k"],
            exclude_same_model=payload["exclude_same_model"],
            exclude_identical=payload["exclude_identical"],
            exclude_models=payload.get("exclude_models") or [],
            source_path=payload["source_path"],
            progress=progress_cb,
        )
        response = {
            "results": results["results"],
            "overall": results["overall"],
            "by_class": results["by_class"],
            # .get() with defaults: these keys may be absent from the result dict.
            "overall_all": results.get("overall_all", []),
            "by_class_all": results.get("by_class_all", {}),
            "identical_filtered": results.get("identical_filtered", 0),
            "agent_report": generate_agent_report(results),
            "index_info": analyzer.index_status(),
        }
        _update_job(job_id, {"status": "done", "message": "done", "result": response})
    except Exception as exc:
        # Surface any failure through the job record so /api/result reports it.
        error_text = f"{type(exc).__name__}: {exc}"
        print(f"[analysis error] job={job_id} {error_text}")
        _update_job(job_id, {"status": "error", "error": error_text, "message": "failed"})
def _find_definition(tree: ast.AST, symbol: str) -> ast.AST | None:
if "." in symbol:
class_name, method_name = symbol.split(".", 1)
for node in tree.body:
if isinstance(node, ast.ClassDef) and node.name == class_name:
for child in node.body:
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == method_name:
return child
return None
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == symbol:
return node
return None
def _call_name(node: ast.AST) -> str | None:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return node.attr
return None
def _summarize_ast(node: ast.AST) -> dict[str, list[dict[str, str | int]]]:
    """Summarize a subtree: most common node types and called names.

    Returns {"node_counts": [...], "calls": [...]} where each entry is
    {"name": <str>, "count": <int>}, limited to the 8 most common of each.

    Note: the inner dicts mix str and int values, so the return annotation
    is dict[str, list[dict[str, str | int]]] (the previous ...dict[str, int]
    annotation was incorrect for the "name" key).
    """
    counts: Counter[str] = Counter()
    calls: Counter[str] = Counter()
    for child in ast.walk(node):
        counts[type(child).__name__] += 1
        if isinstance(child, ast.Call):
            # Only count calls whose target has a recognizable short name.
            name = _call_name(child.func)
            if name:
                calls[name] += 1
    node_counts = [{"name": name, "count": count} for name, count in counts.most_common(8)]
    call_counts = [{"name": name, "count": count} for name, count in calls.most_common(8)]
    return {"node_counts": node_counts, "calls": call_counts}
def _strip_node_docstring(node: ast.AST) -> None:
if not isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef)):
return
if not node.body:
return
first = node.body[0]
if isinstance(first, ast.Expr) and isinstance(getattr(first, "value", None), ast.Constant):
if isinstance(first.value.value, str):
node.body.pop(0)
def _get_structural_flow(node: ast.AST) -> str:
    """Render a compact control/call trace for a subtree as "a -> b -> c".

    Records self.* call targets and attribute accesses, other call names,
    "[LOGIC]" for if/while/for, and "Return" for return statements, then
    collapses consecutive duplicates and keeps the first 20 entries.
    """
    events: list[str] = []
    for current in ast.walk(node):
        if isinstance(current, ast.Call):
            func = current.func
            label = None
            if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name) and func.value.id == "self":
                label = f"self.{func.attr}"
            if label is None:
                label = _call_name(func)
            if label:
                events.append(label)
        elif isinstance(current, ast.Attribute) and isinstance(current.value, ast.Name) and current.value.id == "self":
            events.append(f"self.{current.attr}")
        elif isinstance(current, (ast.If, ast.While, ast.For)):
            events.append("[LOGIC]")
        elif isinstance(current, ast.Return):
            events.append("Return")
    collapsed: list[str] = []
    for label in events:
        if not collapsed or collapsed[-1] != label:
            collapsed.append(label)
    return " -> ".join(collapsed[:20])
def _extract_ast(source: str, symbol: str) -> tuple[str | None, dict[str, object] | None]:
    """Parse *source*, locate *symbol*, and return (unparsed code, details).

    Returns (None, None) when the symbol is not found. The code component is
    None when the located node cannot be unparsed; the details dict carries
    the node-count summary and structural flow string.
    """
    definition = _find_definition(ast.parse(source), symbol)
    if definition is None:
        return None, None
    _strip_node_docstring(definition)
    try:
        rendered = ast.unparse(definition)
    except Exception:
        rendered = None
    details = {"summary": _summarize_ast(definition), "flow": _get_structural_flow(definition)}
    return rendered, details
@app.get("/")
async def index() -> FileResponse:
    """Serve the single-page front-end."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)
@app.post("/api/analyze")
async def analyze(request: AnalyzeRequest) -> dict:
    """Queue a background similarity analysis and return its job id."""
    job_id = uuid4().hex
    initial_state = {
        "status": "queued",
        "progress": 0,
        "total": 1,
        "message": "queued",
        "result": None,
    }
    with _JOBS_LOCK:
        _JOBS[job_id] = initial_state
    worker = threading.Thread(
        target=_run_analysis,
        args=(job_id, request.model_dump()),
        daemon=True,
    )
    worker.start()
    return {"job_id": job_id}
@app.get("/api/progress/{job_id}")
async def progress(job_id: str) -> dict:
    """Report the current state of a queued/running/finished analysis job."""
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")
        snapshot = {
            "status": job.get("status"),
            "progress": job.get("progress", 0),
            "total": job.get("total", 1),
            "message": job.get("message", ""),
            "error": job.get("error"),
        }
    return snapshot
@app.get("/api/models")
async def models() -> dict:
    """List the model names known to the analyzer (for front-end filters)."""
    return {"models": CodeSimilarityAnalyzer.list_models()}
@app.get("/api/result/{job_id}")
async def result(job_id: str) -> dict:
    """Return the finished analysis payload, or an HTTP error for other states."""
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")
        status = job.get("status")
        if status == "error":
            raise HTTPException(status_code=500, detail=job.get("error", "Unknown error"))
        if status != "done":
            raise HTTPException(status_code=409, detail="Job not finished")
        payload = job.get("result")
    return payload or {}
@app.post("/api/ast")
async def ast_view(request: AstRequest) -> dict:
    """Extract and summarize ASTs for the query symbol and an optional match."""
    query_ast, query_summary = _extract_ast(request.code, request.symbol)
    match_ast, match_summary = None, None
    identifier = request.match_identifier
    if identifier and ":" in identifier:
        relative_path, _, match_name = identifier.partition(":")
        candidate = MODELS_ROOT / relative_path
        if candidate.exists():
            match_ast, match_summary = _extract_ast(
                candidate.read_text(encoding="utf-8"), match_name
            )
    return {
        "query_ast": query_ast,
        "query_summary": query_summary,
        "match_ast": match_ast,
        "match_summary": match_summary,
        "match_identifier": identifier,
    }