Spaces:
Running
Running
| """ | |
| Orchestrator β coordinates Security β Performance β Fix agents | |
| and emits SSE events for real-time streaming to the frontend. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import os | |
| import time | |
| from typing import Any, AsyncGenerator, Dict, List, Optional | |
| from api.models import ( | |
| AMDMigrationGuide, | |
| AMDMigrationFindingModel, | |
| AnalysisSummary, | |
| PerformanceFinding, | |
| PrivacyCertificate, | |
| SecurityFinding, | |
| SessionResult, | |
| Severity, | |
| ) | |
| from agents.security_agent import SecurityAgent | |
| from agents.performance_agent import PerformanceAgent | |
| from agents.fix_agent import FixAgent | |
| from agents.amd_migration_advisor import AMDMigrationAdvisor | |
| from amd_metrics import AMDMetricsCollector | |
| from memory.session_store import get_store | |
| from privacy.privacy_guard import ZeroDataRetentionGuard | |
| from tools.code_parser import ( | |
| FileEntry, | |
| build_context_block, | |
| parse_code_string, | |
| parse_directory, | |
| parse_zip_base64, | |
| ) | |
| from tools.github_connector import GitHubConnector | |
| from tools.benchmark_tool import start_benchmark, record_first_finding, finish_benchmark | |
logger = logging.getLogger(__name__)

# Config from environment
# Base URL of the OpenAI-compatible inference endpoint shared by all agents.
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8080/v1")
# Model identifier served at that endpoint.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct")
# Prefer an explicit LLM_API_KEY; fall back to GROQ_API_KEY, else a dummy
# value for local servers that do not check authentication.
LLM_API_KEY = os.getenv("LLM_API_KEY") or os.getenv("GROQ_API_KEY", "not-needed-local")
# Master switch: when "false", agents run static analysis only (no LLM calls).
USE_LLM = os.getenv("USE_LLM", "true").lower() == "true"
| def _sse_event(event: str, data: Dict[str, Any]) -> Dict[str, Any]: | |
| return {"event": event, "data": data} | |
class Orchestrator:
    """
    Master agent. Runs the full analysis pipeline:

        1. Ingest code (GitHub / string / zip)
        2. Security Agent (static + LLM)
        3. Performance Agent (static + LLM)
        4. Fix Agent (diffs + report)
        5. Privacy certificate generation

    Yields SSE event dicts throughout for real-time streaming.
    """

    def __init__(self) -> None:
        # All three LLM-backed agents share the same endpoint / model / key.
        self.security_agent = SecurityAgent(
            vllm_base_url=VLLM_BASE_URL,
            model=MODEL_NAME,
            api_key=LLM_API_KEY,
        )
        self.performance_agent = PerformanceAgent(
            vllm_base_url=VLLM_BASE_URL,
            model=MODEL_NAME,
            api_key=LLM_API_KEY,
        )
        self.fix_agent = FixAgent(
            vllm_base_url=VLLM_BASE_URL,
            model=MODEL_NAME,
            api_key=LLM_API_KEY,
        )
        self.migration_advisor = AMDMigrationAdvisor()
        self.metrics_collector = AMDMetricsCollector()
        self.store = get_store()

    # ──────────────────────────────────────────
    # SSE streaming pipeline
    # ──────────────────────────────────────────
    async def run_stream(
        self,
        source: str,
        source_type: str,
        session_id: str,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Full analysis pipeline yielding SSE event dicts.

        Call from a FastAPI StreamingResponse / EventSourceResponse.

        Args:
            source: Repo URL, base64 zip, or raw code string (per source_type).
            source_type: One of "github", "huggingface", "zip", "code".
            session_id: Session key used for store updates and the privacy cert.

        Yields:
            Dicts of the shape produced by ``_sse_event`` (event name + data).
        """
        start_time = time.perf_counter()
        bench = start_benchmark()
        self.metrics_collector.reset_tokens()

        # Mark the session as running before any work starts.
        await self.store.update(session_id, {"source_type": source_type, "status": "running"})

        # ── AMD Metrics background poller ────────────────────
        metrics_queue: asyncio.Queue = asyncio.Queue()
        metrics_stop = asyncio.Event()

        async def _poll_amd_metrics() -> None:
            """Collect AMD GPU metrics every 2 seconds until stopped/cancelled."""
            try:
                while not metrics_stop.is_set():
                    snapshot = await self.metrics_collector.collect()
                    await metrics_queue.put(snapshot)
                    await asyncio.sleep(2)
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                logger.debug("[Orchestrator] AMD metrics polling error: %s", exc)

        metrics_task = asyncio.create_task(_poll_amd_metrics())

        async def _drain_metrics() -> AsyncGenerator[Dict[str, Any], None]:
            """Yield any queued AMD metric snapshots as SSE events (non-blocking)."""
            while not metrics_queue.empty():
                try:
                    yield _sse_event("amd_metrics", metrics_queue.get_nowait())
                except asyncio.QueueEmpty:
                    break

        try:
            with ZeroDataRetentionGuard(session_id=session_id, enforce_network_block=False) as guard:
                # ── Step 1: Ingest ───────────────────────────────────
                yield _sse_event("status", {"message": "Ingesting code...", "session_id": session_id})
                try:
                    # Parsing can hit the filesystem/network; keep it off the loop.
                    files = await asyncio.to_thread(self._ingest, source, source_type)
                except Exception as exc:
                    yield _sse_event("error", {"message": f"Ingestion failed: {exc}"})
                    await self.store.set_status(session_id, "error")
                    return  # poller is stopped by the finally block below

                yield _sse_event("status", {
                    "message": f"Loaded {len(files)} file(s)",
                    "files_count": len(files),
                })
                code_context = build_context_block(files)

                async for ev in _drain_metrics():
                    yield ev

                # ── Step 2: Security Agent ───────────────────────────
                yield _sse_event("agent_start", {"agent": "security", "status": "scanning"})

                # Static scan first (fast), in a thread to avoid blocking the loop.
                static_security = await asyncio.to_thread(
                    self.security_agent.static_scan, files
                )
                for i, finding in enumerate(static_security):
                    finding.id = f"SEC-STATIC-{i+1}"
                    record_first_finding(bench)
                    yield _sse_event("finding", {
                        "agent": "security",
                        **finding.model_dump(),
                    })
                    await asyncio.sleep(0)  # yield control to event loop

                async for ev in _drain_metrics():
                    yield ev

                # LLM deep scan (skipped entirely when USE_LLM is off).
                if USE_LLM:
                    llm_security = await self.security_agent.llm_scan(code_context, static_security)
                    # Merge with static results and order by severity.
                    security_findings = self.security_agent._merge_findings(static_security, llm_security)
                    security_findings = self.security_agent._sort_by_severity(security_findings)
                    # Emit LLM-enriched findings
                    for i, finding in enumerate(llm_security):
                        finding.id = f"SEC-LLM-{i+1}"
                        record_first_finding(bench)
                        yield _sse_event("finding", {
                            "agent": "security",
                            **finding.model_dump(),
                        })
                        await asyncio.sleep(0)
                else:
                    security_findings = static_security

                yield _sse_event("agent_complete", {
                    "agent": "security",
                    "findings_count": len(security_findings),
                })

                # ── Step 3: Performance Agent ────────────────────────
                yield _sse_event("agent_start", {"agent": "performance", "status": "analyzing"})
                perf_findings = await self.performance_agent.analyze(
                    files, code_context, use_llm=USE_LLM
                )
                for i, pf in enumerate(perf_findings):
                    pf.id = f"PERF-{i+1}"
                    # BUG FIX: the normalized keys must be applied AFTER
                    # model_dump(), otherwise the raw dump clobbers them
                    # (emitting the Enum object / a None saving_mb).
                    payload = {"agent": "performance", **pf.model_dump()}
                    payload["type"] = pf.type.value
                    payload["saving_mb"] = pf.saving_mb or 0
                    payload["suggestion"] = pf.suggestion
                    yield _sse_event("finding", payload)
                    await asyncio.sleep(0)

                yield _sse_event("agent_complete", {
                    "agent": "performance",
                    "optimizations_count": len(perf_findings),
                })

                async for ev in _drain_metrics():
                    yield ev

                # ── Step 3.5: AMD Migration Advisor (best-effort) ────
                amd_migration_result: Optional[Dict] = None
                try:
                    amd_migration_result = await self.migration_advisor.scan(files)
                    for mf in amd_migration_result.get("findings", []):
                        yield _sse_event("amd_migration_finding", mf)
                        await asyncio.sleep(0.05)
                    yield _sse_event("amd_migration_summary", {
                        "compatibility_score": amd_migration_result["compatibility_score"],
                        "compatibility_label": amd_migration_result["compatibility_label"],
                        "total_cuda_patterns_found": amd_migration_result["total_cuda_patterns_found"],
                        "summary": amd_migration_result["summary"],
                    })
                except Exception as exc:
                    # Migration advice is optional; never fail the pipeline on it.
                    logger.warning("[Orchestrator] AMD migration scan failed: %s", exc)

                # ── Step 4: Fix Agent ────────────────────────────────
                yield _sse_event("agent_start", {"agent": "fix", "status": "generating_fixes"})
                fix_result = await self.fix_agent.generate_fixes(
                    files=files,
                    security_findings=security_findings,
                    performance_findings=perf_findings,
                    session_id=session_id,
                    use_llm=USE_LLM,
                )
                # Emit individual fixes for the UI
                for fix in fix_result.finding_fixes:
                    yield _sse_event("fix_ready", fix.model_dump())
                    await asyncio.sleep(0.1)  # tiny delay for UI animation

                yield _sse_event("fix_batch", {
                    "diff": fix_result.diffs[0].diff if fix_result.diffs else "",
                    "files_changed": fix_result.files_changed,
                    "diffs": [d.model_dump() for d in fix_result.diffs],
                })

                # ── Step 5: Summary & Certificate ───────────────────
                # Stop AMD metrics polling before the final bookkeeping.
                metrics_stop.set()
                metrics_task.cancel()

                bench = finish_benchmark(bench, findings=len(security_findings))
                elapsed = time.perf_counter() - start_time

                sev_counts = {s.value: 0 for s in Severity}
                for f in security_findings:
                    sev_counts[f.severity.value] += 1
                total_mem_saving = sum((pf.saving_mb or 0.0) for pf in perf_findings)

                summary = AnalysisSummary(
                    session_id=session_id,
                    total_findings=len(security_findings),
                    critical_count=sev_counts.get("critical", 0),
                    high_count=sev_counts.get("high", 0),
                    medium_count=sev_counts.get("medium", 0),
                    low_count=sev_counts.get("low", 0),
                    performance_optimizations=len(perf_findings),
                    estimated_memory_savings_mb=total_mem_saving,
                    analysis_duration_seconds=round(elapsed, 2),
                    files_analyzed=len(files),
                )

                cert_dict = guard.generate_certificate()
                privacy_cert = PrivacyCertificate(
                    session_id=cert_dict["session_id"],
                    timestamp=cert_dict["timestamp"],
                    guarantee=cert_dict["guarantee"],
                    model_endpoint=cert_dict["model_endpoint"],
                    external_calls_blocked=cert_dict.get("external_calls_blocked", []),
                    data_wiped=cert_dict["data_wiped"],
                    signature=cert_dict["signature"],
                )

                # Build AMD migration guide for the final result (best-effort).
                amd_guide = None
                if amd_migration_result:
                    try:
                        amd_guide = AMDMigrationGuide(
                            compatibility_score=amd_migration_result["compatibility_score"],
                            compatibility_label=amd_migration_result["compatibility_label"],
                            total_cuda_patterns_found=amd_migration_result["total_cuda_patterns_found"],
                            findings=[
                                AMDMigrationFindingModel(**f)
                                for f in amd_migration_result.get("findings", [])
                            ],
                            summary=amd_migration_result.get("summary", ""),
                        )
                    except Exception as exc:
                        logger.debug("[Orchestrator] AMDMigrationGuide build failed: %s", exc)

                # Persist full result to session store
                session_result = SessionResult(
                    session_id=session_id,
                    status="complete",
                    summary=summary,
                    security_findings=security_findings,
                    performance_findings=perf_findings,
                    fix_result=fix_result,
                    privacy_certificate=privacy_cert,
                    amd_migration_guide=amd_guide,
                )
                await self.store.update(session_id, {
                    # BUG FIX: was "_status" — inconsistent with the "status"
                    # key written at pipeline start, so the session never
                    # appeared as complete under the expected key.
                    "status": "complete",
                    "result": session_result.model_dump(mode="json"),
                })

                yield _sse_event("complete", {
                    "privacy_certificate": privacy_cert.model_dump(),
                    "summary": summary.model_dump(),
                    "security_report_available": True,
                    "amd_migration_guide": amd_guide.model_dump() if amd_guide else None,
                })
        finally:
            # Safety net: always stop the poller, even when an agent call
            # raised or the client disconnected mid-stream (generator closed).
            metrics_stop.set()
            if not metrics_task.done():
                metrics_task.cancel()

    # ──────────────────────────────────────────
    # Code ingestion
    # ──────────────────────────────────────────
    def _ingest(self, source: str, source_type: str) -> List[FileEntry]:
        """Route ingestion to the correct parser based on source_type.

        Raises:
            ValueError: if source_type is not one of the supported kinds.
        """
        if source_type == "github":
            with GitHubConnector(source) as repo_dir:
                return parse_directory(repo_dir)
        elif source_type == "huggingface":
            # Imported lazily so the HF dependency is only needed when used.
            from tools.huggingface_connector import HuggingFaceConnector
            with HuggingFaceConnector(source) as repo_dir:
                return parse_directory(repo_dir)
        elif source_type == "zip":
            return parse_zip_base64(source)
        elif source_type == "code":
            return parse_code_string(source, filename="input.py")
        else:
            raise ValueError(f"Unknown source_type: {source_type!r}")

    # ──────────────────────────────────────────
    # Demo mode (pre-computed, no GPU needed)
    # ──────────────────────────────────────────
    async def run_demo(self, session_id: str = "demo") -> SessionResult:
        """
        Return a pre-computed demo result using the vulnerable_ml_code fixture.

        Works without a GPU or vLLM server (static analysis only; no LLM calls).
        """
        import pathlib

        fixture_path = (
            pathlib.Path(__file__).parent.parent
            / "tests" / "fixtures" / "vulnerable_ml_code.py"
        )
        # Fall back to the inline DEMO_CODE snippet when the fixture is absent.
        code = fixture_path.read_text(encoding="utf-8") if fixture_path.exists() else DEMO_CODE
        files: List[FileEntry] = [("vulnerable_ml_code.py", code)]

        # Static-only analysis (no LLM) for demo
        security_findings = self.security_agent.static_scan(files)
        perf_findings = self.performance_agent.static_scan(files)
        fix_result = await self.fix_agent.generate_fixes(
            files, security_findings, perf_findings, session_id, use_llm=False
        )

        sev_counts = {s.value: 0 for s in Severity}
        for f in security_findings:
            sev_counts[f.severity.value] += 1

        summary = AnalysisSummary(
            session_id=session_id,
            total_findings=len(security_findings),
            critical_count=sev_counts.get("critical", 0),
            high_count=sev_counts.get("high", 0),
            medium_count=sev_counts.get("medium", 0),
            low_count=sev_counts.get("low", 0),
            performance_optimizations=len(perf_findings),
            estimated_memory_savings_mb=sum((p.saving_mb or 0) for p in perf_findings),
            analysis_duration_seconds=0.5,
            files_analyzed=1,
        )
        cert = PrivacyCertificate(
            session_id=session_id,
            timestamp="demo",
            guarantee="Demo mode — all inference ran locally (static analysis only).",
            model_endpoint="http://localhost:8080",
            external_calls_blocked=[],
            data_wiped=True,
            signature="demo-signature",
        )
        return SessionResult(
            session_id=session_id,
            status="complete",
            summary=summary,
            security_findings=security_findings,
            performance_findings=perf_findings,
            fix_result=fix_result,
            privacy_certificate=cert,
        )
# Minimal inline demo code (fallback if fixture file missing).
# Intentionally vulnerable, so the static scanners have something to find:
# pickle deserialization of an attacker-chosen path (CWE-502), prompt
# injection into the model call (LLM01), eval of model output (LLM02),
# and a hard-coded HF token.
DEMO_CODE = '''
import pickle, os
from flask import Flask, request

app = Flask(__name__)

HF_TOKEN = "hf_abcdefghijklmnopqrstuvwxyz123456"

@app.route("/predict", methods=["POST"])
def predict():
    model_path = request.json["model_path"]
    model = pickle.load(open(model_path, "rb"))  # CWE-502
    user_prompt = request.json["prompt"]
    result = model.generate(f"Answer: {user_prompt}")  # LLM01
    eval(result)  # LLM02
    return {"result": result}
'''