#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "huggingface_hub>=0.34.0",
#     "datasets",
#     "requests",
# ]
# ///
"""
PitVQA Multi-Agent Orchestration System

Specialized agents for methodologically rigorous VLM pipeline management:
1. JobMonitorAgent           - Track HuggingFace Jobs status
2. CurationAgent             - Quality-filter showcase examples
3. DatasetValidatorAgent     - Validate the image-embedded dataset
4. ModelVerifierAgent        - Check the merged model repo has weights + config
5. TrainingSpecialistAgent   - Validate the LoRA adapter configuration
6. EvaluationSpecialistAgent - Check metrics against quality thresholds
7. DemoSyncAgent             - Check the Gradio Space is running for the examples

Run with: python pitvqa_agent_orchestrator.py
    (or `uv run pitvqa_agent_orchestrator.py` to honour the inline script metadata)
"""

import os
import json
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from datetime import datetime
from enum import Enum


# ============================================================
# Agent Status Types
# ============================================================

class AgentStatus(Enum):
    """Lifecycle / outcome state of an agent."""

    IDLE = "idle"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    WAITING = "waiting"


@dataclass
class AgentResult:
    """Outcome of a single agent run.

    Attributes:
        agent_name: Name of the agent that produced the result.
        status: Final status of the run.
        message: Human-readable summary.
        data: Optional agent-specific payload.
        timestamp: ISO-8601 creation time; filled in automatically when empty.
    """

    agent_name: str
    status: AgentStatus
    message: str
    data: Optional[Dict] = None
    timestamp: str = ""

    def __post_init__(self):
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict (the status enum is rendered as its value)."""
        return {
            "agent_name": self.agent_name,
            "status": self.status.value,
            "message": self.message,
            "data": self.data,
            "timestamp": self.timestamp,
        }


# ============================================================
# Base Agent
# ============================================================

class BaseAgent:
    """Base class for all PitVQA agents.

    Attributes:
        name: Display name used as the log prefix.
        status: Current AgentStatus (starts as IDLE).
        results: AgentResult history included in report().
    """

    def __init__(self, name: str):
        self.name = name
        self.status = AgentStatus.IDLE
        self.results: List[AgentResult] = []

    def log(self, message: str, level: str = "INFO"):
        """Print ``message`` prefixed with the agent name and a level icon."""
        icon = {"INFO": "ℹ️", "SUCCESS": "✅", "ERROR": "❌", "WARN": "⚠️"}.get(level, "📌")
        print(f"[{self.name}] {icon} {message}")

    def run(self) -> AgentResult:
        """Execute the agent; subclasses must override."""
        raise NotImplementedError

    def report(self) -> Dict:
        """Return a JSON-serializable summary of this agent and its results."""
        return {
            "agent": self.name,
            "status": self.status.value,
            # to_dict() converts the AgentStatus enum, which json cannot encode.
            "results": [r.to_dict() for r in self.results],
        }
# ============================================================
# Agent 1: Job Monitor
# ============================================================

class JobMonitorAgent(BaseAgent):
    """Monitors HuggingFace Jobs and reports their aggregate status."""

    # Upper-cased job stages treated as terminal success / terminal failure.
    # Any other stage (RUNNING, or UNKNOWN when the lookup itself fails)
    # keeps the monitor in WAITING.
    COMPLETE_STAGES = frozenset({"COMPLETED", "SUCCESS"})
    FAILED_STAGES = frozenset({"FAILED", "ERROR", "CANCELED", "CANCELLED", "DELETED"})

    def __init__(self, job_ids: List[str]):
        super().__init__("JobMonitor")
        self.job_ids = job_ids
        self.job_status: Dict[str, Dict] = {}

    def check_job(self, job_id: str) -> Dict:
        """Check a single job's status via the HF Jobs API.

        Returns:
            ``{"id", "status", "message"}`` where ``status`` is the job stage as
            an upper-case string, or ``{"id", "status": "UNKNOWN", "error"}`` if
            the lookup fails for any reason (best-effort by design).
        """
        try:
            from huggingface_hub import HfApi

            api = HfApi()
            # huggingface_hub's Jobs API names this call `inspect_job`;
            # `get_job` (what this script originally called) is only a fallback.
            lookup = getattr(api, "inspect_job", None) or getattr(api, "get_job")
            job = lookup(job_id=job_id)
            status = job.status
            stage = getattr(status, "stage", status)
            # The stage may be an Enum; use its plain value so logs and the
            # stage comparisons see "COMPLETED", not "JobStage.COMPLETED".
            stage = getattr(stage, "value", stage)
            return {
                "id": job_id,
                "status": str(stage).upper(),
                "message": getattr(status, "message", None),
            }
        except Exception as e:
            return {"id": job_id, "status": "UNKNOWN", "error": str(e)}

    def run(self) -> AgentResult:
        """Poll every job once.

        Returns FAILED if any job reached a failure stage, SUCCESS if every job
        completed, otherwise WAITING. ``data`` maps job id -> status dict.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Checking {len(self.job_ids)} jobs...")

        stages = []
        for job_id in self.job_ids:
            status = self.check_job(job_id)
            self.job_status[job_id] = status
            stage = status.get("status", "UNKNOWN")
            stages.append(stage)
            self.log(f"Job {job_id[:8]}: {stage}")

        if any(stage in self.FAILED_STAGES for stage in stages):
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, "Some jobs failed", self.job_status)
        if all(stage in self.COMPLETE_STAGES for stage in stages):
            self.status = AgentStatus.SUCCESS
            return AgentResult(self.name, AgentStatus.SUCCESS, "All jobs complete", self.job_status)
        self.status = AgentStatus.WAITING
        return AgentResult(self.name, AgentStatus.WAITING, "Jobs still running", self.job_status)


# ============================================================
# Agent 2: Curation Agent
# ============================================================

class CurationAgent(BaseAgent):
    """Curates showcase examples based on quality and diversity criteria."""

    # NOTE(review): not referenced by any method in this file.
    QUALITY_CRITERIA = {
        "coordinate_validity": lambda x, y: 0 <= x <= 100 and 0 <= y <= 100,
        "coordinate_diversity": lambda coords: len(set(coords)) > len(coords) * 0.5,
        "video_diversity": lambda vids: len(set(vids)) >= min(5, len(vids)),
        "frame_diversity": lambda frames: len(set(frames)) >= min(8, len(frames)),
    }

    # curate(): at most this many examples may come from the same video.
    MAX_PER_VIDEO = 2
    # run(): SUCCESS only when at least this many examples survive curation.
    MIN_CURATED = 8

    def __init__(self, results_path: str = "./curation_review/all_results.json"):
        super().__init__("Curation")
        self.results_path = results_path
        self.curated_examples: List[Dict] = []

    def load_results(self) -> List[Dict]:
        """Load raw curation results; returns [] if the file is missing or unreadable."""
        try:
            with open(self.results_path, encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            self.log("Results file not found - job may still be running", "WARN")
            return []
        except json.JSONDecodeError as e:
            # A partially written file is treated like "not ready yet".
            self.log(f"Results file is not valid JSON ({e}) - job may still be writing it", "WARN")
            return []

    def score_example(self, example: Dict) -> float:
        """Score a single example in [0, 1].

        0.3 for a successful inference, up to 0.3 for plausible geometry
        (point away from the image edge / bbox of moderate area) and up to
        0.4 for response coherence.
        """
        score = 0.0

        # Basic validity
        if example.get("success"):
            score += 0.3

        # Coordinate quality
        task = example.get("task")
        if task == "point":
            x, y = example.get("x"), example.get("y")
            # `is not None` so a legitimate 0 coordinate is still scored.
            if x is not None and y is not None:
                # Penalize edge coordinates (likely failures)
                score += 0.3 if (10 < x < 90 and 10 < y < 90) else 0.1
        elif task == "bbox":
            bbox = example.get("bbox")
            if bbox and len(bbox) == 4:
                # Penalize tiny or huge boxes
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                score += 0.3 if 100 < area < 5000 else 0.1

        # Response coherence.
        # NOTE(review): this section was damaged in the source (the text between
        # a '<' and the '>' of the following '->' was stripped). The checks
        # below are a reconstruction - confirm against the curation output format.
        response = str(example.get("response") or "")
        if "<" in response and ">" in response:
            score += 0.2  # structured, tag-style answer
        if 0 < len(response.strip()) <= 500:
            score += 0.2  # non-empty and not a runaway generation

        return min(score, 1.0)

    def curate(self, results: List[Dict], top_k: int = 12) -> List[Dict]:
        """Select the best-scoring diverse examples.

        Only successful examples are considered, highest score first, subject
        to: at most MAX_PER_VIDEO per video, each (video, frame) pair once, and
        no task more than ``top_k // 2`` times (at least 1).
        NOTE(review): the original default for ``top_k`` was lost in the damaged
        section; 12 is a reconstruction.

        Returns:
            Up to ``top_k`` example dicts, each copied with a ``quality_score``.
        """
        if not results:
            return []

        scored = [(self.score_example(ex), ex) for ex in results if ex.get("success")]
        scored.sort(key=lambda pair: pair[0], reverse=True)

        curated: List[Dict] = []
        per_video: Dict[Any, int] = {}
        used_frames = set()
        per_task: Dict[Any, int] = {}
        task_cap = max(1, top_k // 2)

        for score, ex in scored:
            if len(curated) >= top_k:
                break

            video = ex.get("video_id")
            frame = ex.get("frame_idx")
            task = ex.get("task")

            # Diversity constraints
            if per_video.get(video, 0) >= self.MAX_PER_VIDEO:
                continue
            if (video, frame) in used_frames:  # unique video+frame combos
                continue
            if per_task.get(task, 0) >= task_cap:  # balance tasks
                continue

            curated.append({**ex, "quality_score": score})
            per_video[video] = per_video.get(video, 0) + 1
            used_frames.add((video, frame))
            per_task[task] = per_task.get(task, 0) + 1

        return curated

    def run(self) -> AgentResult:
        """Load, score and curate results; SUCCESS with ``data["examples"]`` if enough survive."""
        self.status = AgentStatus.RUNNING
        self.log("Loading curation results...")

        results = self.load_results()
        if not results:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "No results available yet")

        self.log(f"Scoring {len(results)} examples...")
        self.curated_examples = self.curate(results)

        if len(self.curated_examples) >= self.MIN_CURATED:
            self.status = AgentStatus.SUCCESS

            # Report diversity
            videos = {ex.get("video_id") for ex in self.curated_examples}
            frames = {ex.get("frame_idx") for ex in self.curated_examples}
            self.log(f"Curated {len(self.curated_examples)} examples", "SUCCESS")
            self.log(f"  Videos: {len(videos)} unique")
            self.log(f"  Frames: {len(frames)} unique")

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Curated {len(self.curated_examples)} high-quality diverse examples",
                {"examples": self.curated_examples},
            )

        self.status = AgentStatus.FAILED
        return AgentResult(
            self.name,
            AgentStatus.FAILED,
            f"Only {len(self.curated_examples)} examples passed quality checks",
        )


# ============================================================
# Agent 3: Dataset Validator
# ============================================================

class DatasetValidatorAgent(BaseAgent):
    """Validates a sample of the image-embedded dataset."""

    REQUIRED_FIELDS = ["image", "messages"]

    def __init__(
        self,
        dataset_id: str = "mmrech/pitvqa-spatial-with-images",
        sample_split: str = "train[:10]",
    ):
        super().__init__("DatasetValidator")
        self.dataset_id = dataset_id
        # `datasets` split expression selecting the rows to validate.
        self.sample_split = sample_split

    @staticmethod
    def _is_valid_image(img: Any) -> bool:
        """True when ``img`` exposes a ``size`` with positive width and height."""
        size = getattr(img, "size", None)
        return (
            img is not None
            and size is not None
            and len(size) >= 2
            and size[0] > 0
            and size[1] > 0
        )

    def run(self) -> AgentResult:
        """Load a sample split and check required fields and decodable images.

        FAILED: missing package, missing fields, empty split, unreadable or
        invalid images. WAITING: the dataset cannot be loaded from the Hub yet.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Validating dataset: {self.dataset_id}")

        try:
            from datasets import load_dataset
        except ImportError as e:
            # A missing dependency is a setup error, not "dataset not ready".
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"`datasets` package not installed: {e}")

        try:
            ds = load_dataset(self.dataset_id, split=self.sample_split)
        except Exception as e:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, f"Dataset not yet available: {e}")

        # Check required fields
        missing = [f for f in self.REQUIRED_FIELDS if f not in ds.features]
        if missing:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Missing fields: {missing}")

        total = len(ds)
        if total == 0:
            # An empty split would otherwise pass as "0/0 images OK".
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, "Dataset split is empty")

        # Validate image quality (images may be decoded lazily while iterating).
        try:
            valid_images = sum(1 for ex in ds if self._is_valid_image(ex.get("image")))
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error reading samples: {e}")

        if valid_images == total:
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Dataset valid: {valid_images}/{total} images OK",
                {"sample_count": total, "valid_images": valid_images},
            )

        self.status = AgentStatus.FAILED
        return AgentResult(
            self.name,
            AgentStatus.FAILED,
            f"Invalid images: {total - valid_images}/{total}",
        )


# ============================================================
# Agent 4: Model Verifier
# ============================================================

class ModelVerifierAgent(BaseAgent):
    """Verifies the merged model repo contains weights and a config."""

    # NOTE(review): not referenced by any method in this file.
    TEST_PROMPTS = [
        ("Point to the suction device", "point"),
        ("Draw a bounding box around the surgical instrument", "bbox"),
        ("What surgical phase is this?", "classification"),
    ]

    def __init__(self, model_id: str = "mmrech/pitvqa-qwen2vl-merged"):
        super().__init__("ModelVerifier")
        self.model_id = model_id

    @staticmethod
    def _is_weight_file(fname: str) -> bool:
        """True for top-level safetensors shards or ``pytorch_model*.bin`` files.

        Excludes e.g. ``training_args.bin`` and weights in sub-folders, which do
        not make the repo loadable as a merged model.
        """
        if "/" in fname:
            return False
        return fname.endswith(".safetensors") or (
            fname.startswith("pytorch_model") and fname.endswith(".bin")
        )

    def run(self) -> AgentResult:
        """Check the repo exists and has a top-level config.json plus weights.

        WAITING if the repo cannot be fetched yet, FAILED if files are missing
        or the Hub client cannot be created.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Verifying model: {self.model_id}")

        try:
            from huggingface_hub import HfApi

            api = HfApi()
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")

        try:
            info = api.model_info(self.model_id)
        except Exception as e:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, f"Model not yet available: {e}")

        self.log(f"Model found: {getattr(info, 'id', None) or self.model_id}")

        files = [s.rfilename for s in (getattr(info, "siblings", None) or [])]
        has_model = any(self._is_weight_file(f) for f in files)
        has_config = "config.json" in files

        if has_model and has_config:
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Model verified: {len(files)} files present",
                {"files": files[:10]},  # First 10 files
            )

        self.status = AgentStatus.FAILED
        return AgentResult(
            self.name,
            AgentStatus.FAILED,
            f"Missing model files (has_model={has_model}, has_config={has_config})",
        )


# ============================================================
# Agent 5: Training Specialist (HF-LLM-Trainer)
# ============================================================

class TrainingSpecialistAgent(BaseAgent):
    """
    Specialist in HuggingFace LLM Training (TRL/SFT/LoRA/DPO).

    Responsibilities:
    - Validate training configurations
    - Check adapter quality
    - Recommend training improvements
    - Verify LoRA/PEFT setup
    """

    TRAINING_METHODS = {
        "SFT": "Supervised Fine-Tuning - learning from (input, output) pairs",
        "LoRA": "Low-Rank Adaptation - parameter-efficient adapters",
        "DPO": "Direct Preference Optimization - learning from preferences",
        "RLHF": "Reinforcement Learning from Human Feedback",
    }

    OPTIMAL_CONFIG = {
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 1,
        "gradient_accumulation_steps": 16,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    }

    def __init__(
        self,
        adapter_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2",
        adapter_config_path: str = "stage4/adapter_config.json",
    ):
        super().__init__("TrainingSpecialist")
        self.adapter_repo = adapter_repo
        # Path of the PEFT adapter config inside `adapter_repo`.
        self.adapter_config_path = adapter_config_path

    def validate_adapter_config(self) -> Dict:
        """Download and sanity-check the adapter config.

        Returns:
            ``{"config", "issues", "recommendations", "valid"}``, or
            ``{"error", "valid": False}`` if the config cannot be loaded.
        """
        try:
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=self.adapter_repo,
                filename=self.adapter_config_path,
            )
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            return {"error": str(e), "valid": False}

        if not isinstance(config, dict):
            return {"error": "adapter config is not a JSON object", "valid": False}

        issues: List[str] = []
        recommendations: List[str] = []

        # Check LoRA rank (a missing or null "r" counts as 0).
        rank = config.get("r") or 0
        if rank < 8:
            issues.append("LoRA rank too low (r < 8)")
        elif rank > 64:
            recommendations.append("Consider reducing LoRA rank for efficiency")

        # Check target modules. PEFT also accepts a single string (a module
        # regex or "all-linear"); iterating it directly would yield characters.
        target_modules = config.get("target_modules") or []
        if isinstance(target_modules, str):
            target_modules = [target_modules]
        if not any("proj" in m or m == "all-linear" for m in target_modules):
            issues.append("No projection layers targeted")

        return {
            "config": config,
            "issues": issues,
            "recommendations": recommendations,
            "valid": not issues,
        }

    def recommend_next_training(self, current_metrics: Optional[Dict] = None) -> Dict:
        """Recommend next training steps from an ``accuracy`` metric.

        Without metrics, recommends running an evaluation first. Otherwise
        buckets accuracy into < 0.7, [0.7, 0.85) and >= 0.85.
        """
        recommendations = []

        if not current_metrics:
            recommendations.append({
                "priority": "HIGH",
                "action": "Run evaluation to get baseline metrics",
                "method": "scripts/evaluate_unified_vlm.py",
            })
        else:
            accuracy = current_metrics.get("accuracy") or 0
            if accuracy < 0.7:
                recommendations.append({
                    "priority": "HIGH",
                    "action": "Increase training epochs or data",
                    "method": "SFT with more epochs",
                })
            elif accuracy < 0.85:
                recommendations.append({
                    "priority": "MEDIUM",
                    "action": "Consider DPO for preference learning",
                    "method": "Create chosen/rejected pairs from predictions",
                })
            else:
                recommendations.append({
                    "priority": "LOW",
                    "action": "Model performing well - focus on inference optimization",
                    "method": "Merge adapters, quantize for deployment",
                })

        return {"recommendations": recommendations}

    def run(self) -> AgentResult:
        """Validate the adapter: SUCCESS if valid, WAITING if unloadable, FAILED on issues."""
        self.status = AgentStatus.RUNNING
        self.log(f"Validating training setup: {self.adapter_repo}")

        validation = self.validate_adapter_config()

        if validation.get("valid"):
            self.status = AgentStatus.SUCCESS
            # Called without metrics, so this yields the "run evaluation" advice.
            recommendations = self.recommend_next_training()
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Training config valid. LoRA r={validation['config'].get('r')}",
                {
                    "config": validation["config"],
                    "recommendations": recommendations["recommendations"],
                },
            )
        if validation.get("error"):
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Could not load adapter: {validation['error']}",
            )
        self.status = AgentStatus.FAILED
        return AgentResult(
            self.name,
            AgentStatus.FAILED,
            f"Issues found: {validation['issues']}",
            validation,
        )


# ============================================================
# Agent 6: Evaluation Specialist
# ============================================================

class EvaluationSpecialistAgent(BaseAgent):
    """
    Specialist in Model Evaluation (metrics, benchmarks, validation).

    Responsibilities:
    - Compute accuracy, F1, precision, recall
    - Validate coordinate predictions (MAE, quadrant accuracy)
    - Compare against baselines
    - Generate evaluation reports
    """

    METRICS = {
        "classification": ["accuracy", "f1", "precision", "recall"],
        "localization": ["mae", "quadrant_accuracy", "distance_error"],
        "detection": ["iou", "ap", "ar"],
    }

    THRESHOLDS = {
        "quadrant_accuracy": 0.75,  # Minimum acceptable
        "mae": 15.0,  # Maximum acceptable (percentage)
        "classification_accuracy": 0.80,
    }

    # Metrics for which a smaller value is better (threshold is a maximum).
    LOWER_IS_BETTER = frozenset({"mae"})

    def __init__(
        self,
        model_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2",
        results_path: str = "evaluation_results.json",
    ):
        super().__init__("EvaluationSpecialist")
        self.model_repo = model_repo
        # Precomputed metrics file; when present it takes precedence in run().
        self.results_path = results_path
        self.metrics: Dict = {}

    def load_evaluation_results(self) -> Dict:
        """Load existing evaluation metrics; {} if missing, malformed or not an object."""
        try:
            with open(self.results_path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except json.JSONDecodeError as e:
            self.log(f"Ignoring unreadable {self.results_path}: {e}", "WARN")
            return {}
        return data if isinstance(data, dict) else {}

    def compute_quick_metrics(self, predictions: List[Dict]) -> Dict:
        """Compute quick metrics from prediction dicts.

        Point predictions (``task`` "point"/"pointing") yield ``valid_rate`` and,
        where ``gt_x``/``gt_y`` exist, ``mae`` and ``quadrant_accuracy``.
        Classification predictions yield ``classification_accuracy``.
        """
        if not predictions:
            return {}

        metrics: Dict = {}

        # Coordinate predictions
        coord_preds = [p for p in predictions if p.get("task") in ("point", "pointing")]
        if coord_preds:
            valid = [p for p in coord_preds if p.get("x") is not None and p.get("y") is not None]
            metrics["valid_rate"] = len(valid) / len(coord_preds)

            # `is not None` so ground truth at coordinate 0 is not skipped.
            errors = [
                ((p["x"] - p["gt_x"]) ** 2 + (p["y"] - p["gt_y"]) ** 2) ** 0.5
                for p in valid
                if p.get("gt_x") is not None and p.get("gt_y") is not None
            ]
            if errors:
                # Despite the key name, "mae" is the mean Euclidean distance and
                # "quadrant_accuracy" the share of predictions within 25 units.
                metrics["mae"] = sum(errors) / len(errors)
                metrics["quadrant_accuracy"] = sum(1 for e in errors if e < 25) / len(errors)

        # Classification predictions
        class_preds = [p for p in predictions if p.get("task") == "classification"]
        if class_preds:
            correct = sum(1 for p in class_preds if p.get("prediction") == p.get("ground_truth"))
            metrics["classification_accuracy"] = correct / len(class_preds)

        return metrics

    def evaluate_against_thresholds(self, metrics: Dict) -> Dict:
        """Split THRESHOLDS into passed / failed / warnings (metric unavailable)."""
        results = {"passed": [], "failed": [], "warnings": []}

        for metric, threshold in self.THRESHOLDS.items():
            value = metrics.get(metric)
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                results["warnings"].append({"metric": metric, "reason": "not available"})
                continue

            if metric in self.LOWER_IS_BETTER:
                passed = value <= threshold
            else:
                passed = value >= threshold

            entry = {"metric": metric, "value": value, "threshold": threshold}
            results["passed" if passed else "failed"].append(entry)

        return results

    def generate_report(self, metrics: Dict, threshold_results: Dict) -> str:
        """Render metrics and threshold outcomes as a multi-line text report."""
        report = []
        report.append("=" * 50)
        report.append("EVALUATION REPORT")
        report.append("=" * 50)

        report.append("\n📊 METRICS:")
        for k, v in metrics.items():
            report.append(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")

        report.append("\n✅ PASSED:")
        for item in threshold_results["passed"]:
            report.append(f"  {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        if threshold_results["failed"]:
            report.append("\n❌ FAILED:")
            for item in threshold_results["failed"]:
                report.append(f"  {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        if threshold_results.get("warnings"):
            report.append("\n⚠️ NOT EVALUATED:")
            for item in threshold_results["warnings"]:
                report.append(f"  {item['metric']}: {item['reason']}")

        return "\n".join(report)

    def run(self, predictions: Optional[List[Dict]] = None) -> AgentResult:
        """Evaluate existing results (preferred) or ``predictions``.

        FAILED if any thresholded metric misses its threshold, WAITING if no
        input exists or no thresholded metric could be computed, else SUCCESS.
        """
        self.status = AgentStatus.RUNNING
        self.log("Running evaluation...")

        existing = self.load_evaluation_results()
        if existing:
            self.log("Found existing evaluation results")
            self.metrics = existing
        elif predictions:
            self.log(f"Computing metrics from {len(predictions)} predictions")
            self.metrics = self.compute_quick_metrics(predictions)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No predictions available for evaluation",
            )

        threshold_results = self.evaluate_against_thresholds(self.metrics)
        report = self.generate_report(self.metrics, threshold_results)
        self.log(f"\n{report}")

        data = {"metrics": self.metrics, "thresholds": threshold_results}
        if threshold_results["failed"]:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"{len(threshold_results['failed'])} metrics below threshold",
                data,
            )
        if not threshold_results["passed"]:
            # e.g. showcase examples without ground truth: nothing was checked,
            # so this must not be reported as a pass.
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No thresholded metrics could be computed (ground truth missing?)",
                data,
            )
        self.status = AgentStatus.SUCCESS
        return AgentResult(
            self.name,
            AgentStatus.SUCCESS,
            f"All {len(threshold_results['passed'])} metrics passed",
            data,
        )


# ============================================================
# Agent 7: Demo Sync Agent
# ============================================================

class DemoSyncAgent(BaseAgent):
    """Checks the Gradio Space is running so curated examples can be synced.

    It reports readiness only; it does not upload the examples itself.
    """

    def __init__(self, space_id: str = "mmrech/pitvqa-surgical-vlm"):
        super().__init__("DemoSync")
        self.space_id = space_id

    def run(self, curated_examples: Optional[List[Dict]] = None) -> AgentResult:
        """SUCCESS when examples exist and the Space stage is RUNNING.

        WAITING when there are no examples or the Space is not running; FAILED
        when the Hub client cannot be created or Space info cannot be fetched.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Syncing to Space: {self.space_id}")

        if not curated_examples:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "No curated examples to sync")

        try:
            from huggingface_hub import HfApi

            api = HfApi()
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")

        try:
            info = api.space_info(self.space_id)
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Space error: {e}")

        runtime = getattr(info, "runtime", None)
        stage = getattr(runtime, "stage", None)
        # The stage may be an Enum; compare and display its plain value.
        stage = getattr(stage, "value", stage)

        if stage == "RUNNING":
            self.log("Space is running", "SUCCESS")
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Space running, {len(curated_examples)} examples ready for sync",
                {"space_status": "RUNNING", "examples_count": len(curated_examples)},
            )

        self.status = AgentStatus.WAITING
        return AgentResult(
            self.name,
            AgentStatus.WAITING,
            f"Space not running: {stage or 'unknown'}",
        )


# ============================================================
# Orchestrator
# ============================================================

class PitVQAOrchestrator:
    """Coordinates all agents for the PitVQA pipeline."""

    def __init__(self, job_ids: List[str]):
        self.agents = {
            "monitor": JobMonitorAgent(job_ids),
            "curation": CurationAgent(),
            "dataset": DatasetValidatorAgent(),
            "model": ModelVerifierAgent(),
            "training": TrainingSpecialistAgent(),  # HF-LLM-Trainer specialist
            "evaluation": EvaluationSpecialistAgent(),  # Eval-Model specialist
            "demo": DemoSyncAgent(),
        }
        # Results of the most recent cycle, keyed like `self.agents`.
        self.results: Dict[str, AgentResult] = {}
        self.run_count = 0

    def _run_agent(self, key: str, *args, **kwargs) -> AgentResult:
        """Run one agent and record its result on the agent and for this cycle."""
        agent = self.agents[key]
        result = agent.run(*args, **kwargs)
        agent.results.append(result)
        self.results[key] = result
        return result

    def run_cycle(self) -> Dict:
        """Run one orchestration cycle and return its report.

        Job monitoring and training validation always run. The downstream
        phases run unless the monitor reports a failed job; each downstream
        agent reports WAITING on its own when its inputs are not ready.
        """
        self.run_count += 1
        # Start clean so agents skipped this cycle don't report stale results.
        self.results = {}
        print(f"\n{'=' * 60}")
        print(f"🔄 ORCHESTRATION CYCLE {self.run_count}")
        print(f"{'=' * 60}")

        print("\n📊 Phase 1: Job Monitoring")
        monitor_result = self._run_agent("monitor")

        print("\n🎓 Phase 2: Training Validation (HF-LLM-Trainer)")
        self._run_agent("training")

        if monitor_result.status in (AgentStatus.SUCCESS, AgentStatus.WAITING):
            print("\n🎨 Phase 3: Curation")
            curation_result = self._run_agent("curation")

            print("\n📦 Phase 4: Dataset Validation")
            self._run_agent("dataset")

            print("\n🤖 Phase 5: Model Verification")
            self._run_agent("model")

            # Only a successful curation result carries examples in `data`.
            curated = (curation_result.data or {}).get("examples", [])

            print("\n📈 Phase 6: Evaluation (Metrics & Quality)")
            self._run_agent("evaluation", predictions=curated)

            print("\n🌐 Phase 7: Demo Sync")
            self._run_agent("demo", curated)
        else:
            print("\n⏭️ Skipping downstream phases: a monitored job failed")

        return self.generate_report()

    def generate_report(self) -> Dict:
        """Generate a JSON-serializable status report for the current cycle."""
        report = {
            "timestamp": datetime.now().isoformat(),
            "cycle": self.run_count,
            "overall_status": self._compute_overall_status(),
            "agents": {},
        }

        for name, result in self.results.items():
            report["agents"][name] = {
                "status": result.status.value,
                "message": result.message,
            }

        return report

    def _compute_overall_status(self) -> str:
        """Collapse per-agent statuses into one pipeline status string."""
        statuses = [r.status for r in self.results.values()]

        if not statuses:
            # all([]) is True; without this guard an empty cycle read "COMPLETE".
            return "UNKNOWN"
        if all(s == AgentStatus.SUCCESS for s in statuses):
            return "COMPLETE"
        if any(s == AgentStatus.FAILED for s in statuses):
            return "NEEDS_ATTENTION"
        if any(s == AgentStatus.WAITING for s in statuses):
            return "IN_PROGRESS"
        return "UNKNOWN"

    def print_summary(self, report: Dict):
        """Print a human-readable summary of ``report``."""
        print(f"\n{'=' * 60}")
        print("📋 ORCHESTRATION SUMMARY")
        print(f"{'=' * 60}")
        print(f"Time: {report['timestamp']}")
        print(f"Cycle: {report['cycle']}")
        print(f"Overall: {report['overall_status']}")
        print("\nAgent Status:")

        for name, info in report["agents"].items():
            icon = {"success": "✅", "failed": "❌", "waiting": "⏳", "running": "🔄"}.get(info["status"], "❓")
            print(f"  {icon} {name}: {info['status']} - {str(info['message'])[:50]}")
# ============================================================
# Main
# ============================================================

def main(
    job_ids: Optional[List[str]] = None,
    report_path: str = "orchestration_report.json",
) -> Dict:
    """Run one orchestration cycle, print a summary and save the JSON report.

    Args:
        job_ids: HuggingFace Job IDs to monitor. Defaults to the current
            curation and merge jobs.
        report_path: File the JSON report is written to.

    Returns:
        The report dict produced by ``PitVQAOrchestrator.run_cycle``.
    """
    print("🚀 PitVQA Multi-Agent Orchestrator Starting...")

    if job_ids is None:
        # Current job IDs
        job_ids = [
            "696cfe9946affbb321046bd9",  # Curation job
            "696cfebf57a10a9d296ca042",  # Merge job
        ]

    orchestrator = PitVQAOrchestrator(job_ids)

    # Run orchestration cycle
    report = orchestrator.run_cycle()
    orchestrator.print_summary(report)

    # Save report
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"\n💾 Report saved to {report_path}")

    return report


if __name__ == "__main__":
    main()