""" |
|
|
PitVQA Multi-Agent Orchestration System |
|
|
|
|
|
Specialized agents for methodologically rigorous VLM pipeline management: |
|
|
1. JobMonitorAgent - Track HuggingFace Jobs status |
|
|
2. CurationAgent - Quality-filter showcase examples |
|
|
3. DatasetAgent - Validate image-embedded dataset |
|
|
4. ModelVerifierAgent - Test merged model outputs |
|
|
5. DemoSyncAgent - Update Gradio Space with results |
|
|
|
|
|
Run with: python pitvqa_agent_orchestrator.py |
|
|
""" |
|
|
|
|
|

import json
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional


class AgentStatus(Enum):
    IDLE = "idle"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    WAITING = "waiting"


@dataclass
class AgentResult:
    agent_name: str
    status: AgentStatus
    message: str
    data: Optional[Dict] = None
    timestamp: str = ""

    def __post_init__(self):
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()
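
# Usage note (illustrative): AgentResult stamps itself with the current time
# when no timestamp is supplied, so callers only pass name/status/message.
#
#   result = AgentResult("Demo", AgentStatus.SUCCESS, "done")
#   assert result.timestamp  # ISO-8601 string from datetime.now()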


class BaseAgent:
    """Base class for all PitVQA agents."""

    def __init__(self, name: str):
        self.name = name
        self.status = AgentStatus.IDLE
        self.results: List[AgentResult] = []

    def log(self, message: str, level: str = "INFO"):
        icon = {"INFO": "ℹ️", "SUCCESS": "✅", "ERROR": "❌", "WARN": "⚠️"}.get(level, "📝")
        print(f"[{self.name}] {icon} {message}")

    def run(self) -> AgentResult:
        raise NotImplementedError

    def report(self) -> Dict:
        return {
            "agent": self.name,
            "status": self.status.value,
            "results": [r.__dict__ for r in self.results],
        }
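
# Minimal subclass sketch (illustrative): agents implement run() and return an
# AgentResult; logging and reporting come from BaseAgent.
#
#   class PingAgent(BaseAgent):
#       def __init__(self):
#           super().__init__("Ping")
#
#       def run(self) -> AgentResult:
#           self.status = AgentStatus.SUCCESS
#           return AgentResult(self.name, AgentStatus.SUCCESS, "pong")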


class JobMonitorAgent(BaseAgent):
    """Monitors HuggingFace Jobs and reports status."""

    def __init__(self, job_ids: List[str]):
        super().__init__("JobMonitor")
        self.job_ids = job_ids
        self.job_status = {}

    def check_job(self, job_id: str) -> Dict:
        """Check a single job's status via the HF API."""
        try:
            from huggingface_hub import HfApi
            api = HfApi()

            job = api.get_job(job_id)
            return {
                "id": job_id,
                "status": job.status.stage if hasattr(job.status, "stage") else str(job.status),
                "message": job.status.message if hasattr(job.status, "message") else None,
            }
        except Exception as e:
            return {"id": job_id, "status": "UNKNOWN", "error": str(e)}

    def run(self) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log(f"Checking {len(self.job_ids)} jobs...")

        all_complete = True
        any_failed = False

        for job_id in self.job_ids:
            status = self.check_job(job_id)
            self.job_status[job_id] = status

            stage = status.get("status", "UNKNOWN")
            self.log(f"Job {job_id[:8]}: {stage}")

            if stage not in ["COMPLETED", "SUCCESS"]:
                all_complete = False
            if stage in ["FAILED", "ERROR"]:
                any_failed = True

        if any_failed:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, "Some jobs failed", self.job_status)
        elif all_complete:
            self.status = AgentStatus.SUCCESS
            return AgentResult(self.name, AgentStatus.SUCCESS, "All jobs complete", self.job_status)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "Jobs still running", self.job_status)
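
# Standalone usage sketch (illustrative; the job ID is a placeholder):
#
#   monitor = JobMonitorAgent(["<job-id>"])
#   result = monitor.run()
#   print(result.status)  # AgentStatus.SUCCESS, .FAILED, or .WAITING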


class CurationAgent(BaseAgent):
    """Curates showcase examples based on quality criteria."""

    # Reference criteria for example quality; curate() applies its own checks
    # and does not evaluate these lambdas directly.
    QUALITY_CRITERIA = {
        "coordinate_validity": lambda x, y: 0 <= x <= 100 and 0 <= y <= 100,
        "coordinate_diversity": lambda coords: len(set(coords)) > len(coords) * 0.5,
        "video_diversity": lambda vids: len(set(vids)) >= min(5, len(vids)),
        "frame_diversity": lambda frames: len(set(frames)) >= min(8, len(frames)),
    }

    def __init__(self, results_path: str = "./curation_review/all_results.json"):
        super().__init__("Curation")
        self.results_path = results_path
        self.curated_examples = []

    def load_results(self) -> List[Dict]:
        """Load raw curation results."""
        try:
            with open(self.results_path) as f:
                return json.load(f)
        except FileNotFoundError:
            self.log("Results file not found - job may still be running", "WARN")
            return []

    def score_example(self, example: Dict) -> float:
        """Score a single example (0-1)."""
        score = 0.0

        # Base credit for a successful generation.
        if example.get("success"):
            score += 0.3

        # Spatial plausibility: prefer predictions away from the image border.
        if example.get("task") == "point":
            x, y = example.get("x"), example.get("y")
            if x is not None and y is not None:
                if 10 < x < 90 and 10 < y < 90:
                    score += 0.3
                else:
                    score += 0.1
        elif example.get("task") == "bbox":
            bbox = example.get("bbox")
            if bbox and len(bbox) == 4:
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                if 100 < area < 5000:
                    score += 0.3
                else:
                    score += 0.1

        # Structured output: the response should carry a point/box tag.
        response = example.get("response", "")
        if "<point" in response or "<box" in response:
            score += 0.2

        # Faithfulness: the response should mention the requested target.
        target = example.get("target", "")
        if target and target.lower() in response.lower():
            score += 0.2

        return min(score, 1.0)
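
    # Worked example of the rubric (illustrative): a successful point
    # prediction at (45, 60) whose response contains "<point" and echoes the
    # target word scores 0.3 + 0.3 + 0.2 + 0.2 = 1.0.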

    def curate(self, results: List[Dict], top_k: int = 12) -> List[Dict]:
        """Select the best diverse examples."""
        if not results:
            return []

        # Rank successful examples by quality score, best first.
        scored = [(self.score_example(ex), ex) for ex in results if ex.get("success")]
        scored.sort(key=lambda x: x[0], reverse=True)

        # Greedily pick winners while enforcing diversity constraints.
        curated = []
        video_counts = {}
        used_frames = set()
        used_tasks = {"point": 0, "bbox": 0}

        for score, ex in scored:
            if len(curated) >= top_k:
                break

            video = ex.get("video_id")
            frame = ex.get("frame_idx")
            task = ex.get("task")

            # At most two picks per video, one per frame, and a balanced
            # split across tasks.
            if video_counts.get(video, 0) >= 2:
                continue
            if (video, frame) in used_frames:
                continue
            if used_tasks.get(task, 0) >= top_k // 2:
                continue

            curated.append({**ex, "quality_score": score})
            video_counts[video] = video_counts.get(video, 0) + 1
            used_frames.add((video, frame))
            used_tasks[task] = used_tasks.get(task, 0) + 1

        return curated

    def run(self) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log("Loading curation results...")

        results = self.load_results()
        if not results:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "No results available yet")

        self.log(f"Scoring {len(results)} examples...")
        self.curated_examples = self.curate(results)

        if len(self.curated_examples) >= 8:
            self.status = AgentStatus.SUCCESS

            videos = {ex["video_id"] for ex in self.curated_examples}
            frames = {ex["frame_idx"] for ex in self.curated_examples}

            self.log(f"Curated {len(self.curated_examples)} examples", "SUCCESS")
            self.log(f"  Videos: {len(videos)} unique")
            self.log(f"  Frames: {len(frames)} unique")

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Curated {len(self.curated_examples)} high-quality diverse examples",
                {"examples": self.curated_examples},
            )
        else:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Only {len(self.curated_examples)} examples passed quality checks",
            )


class DatasetValidatorAgent(BaseAgent):
    """Validates image-embedded dataset quality."""

    def __init__(self, dataset_id: str = "mmrech/pitvqa-spatial-with-images"):
        super().__init__("DatasetValidator")
        self.dataset_id = dataset_id

    def run(self) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log(f"Validating dataset: {self.dataset_id}")

        try:
            from datasets import load_dataset

            # Sample the first ten rows rather than downloading the full set.
            ds = load_dataset(self.dataset_id, split="train[:10]")

            # Schema check.
            required_fields = ["image", "messages"]
            missing = [f for f in required_fields if f not in ds.features]

            if missing:
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    f"Missing fields: {missing}",
                )

            # Image integrity check.
            valid_images = 0
            for ex in ds:
                img = ex.get("image")
                if img and hasattr(img, "size") and img.size[0] > 0:
                    valid_images += 1

            if valid_images == len(ds):
                self.status = AgentStatus.SUCCESS
                return AgentResult(
                    self.name,
                    AgentStatus.SUCCESS,
                    f"Dataset valid: {valid_images}/{len(ds)} images OK",
                    {"sample_count": len(ds), "valid_images": valid_images},
                )
            else:
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    f"Invalid images: {len(ds) - valid_images}/{len(ds)}",
                )

        except Exception as e:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Dataset not yet available: {e}",
            )
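
# Usage sketch (illustrative). "train[:10]" is the datasets library's
# split-slicing syntax, so only the first ten rows are materialized:
#
#   validator = DatasetValidatorAgent()
#   result = validator.run()  # WAITING until the dataset repo is available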


class ModelVerifierAgent(BaseAgent):
    """Verifies merged model outputs are correct."""

    # Prompts reserved for inference smoke tests; run() currently only checks
    # that the repo contents look complete.
    TEST_PROMPTS = [
        ("Point to the suction device", "point"),
        ("Draw a bounding box around the surgical instrument", "bbox"),
        ("What surgical phase is this?", "classification"),
    ]

    def __init__(self, model_id: str = "mmrech/pitvqa-qwen2vl-merged"):
        super().__init__("ModelVerifier")
        self.model_id = model_id

    def run(self) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log(f"Verifying model: {self.model_id}")

        try:
            from huggingface_hub import HfApi
            api = HfApi()

            try:
                info = api.model_info(self.model_id)
                self.log(f"Model found: {info.modelId}")

                # The repo must ship both weights and a config.
                files = [f.rfilename for f in info.siblings]
                has_model = any("safetensors" in f or "pytorch" in f for f in files)
                has_config = "config.json" in files

                if has_model and has_config:
                    self.status = AgentStatus.SUCCESS
                    return AgentResult(
                        self.name,
                        AgentStatus.SUCCESS,
                        f"Model verified: {len(files)} files present",
                        {"files": files[:10]},
                    )
                else:
                    self.status = AgentStatus.FAILED
                    return AgentResult(
                        self.name,
                        AgentStatus.FAILED,
                        f"Missing model files (has_model={has_model}, has_config={has_config})",
                    )

            except Exception as e:
                self.status = AgentStatus.WAITING
                return AgentResult(
                    self.name,
                    AgentStatus.WAITING,
                    f"Model not yet available: {e}",
                )

        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")


class TrainingSpecialistAgent(BaseAgent):
    """
    Specialist in HuggingFace LLM training (TRL/SFT/LoRA/DPO).

    Responsibilities:
    - Validate training configurations
    - Check adapter quality
    - Recommend training improvements
    - Verify LoRA/PEFT setup
    """

    TRAINING_METHODS = {
        "SFT": "Supervised Fine-Tuning - learning from (input, output) pairs",
        "LoRA": "Low-Rank Adaptation - parameter-efficient adapters",
        "DPO": "Direct Preference Optimization - learning from preferences",
        "RLHF": "Reinforcement Learning from Human Feedback",
    }

    OPTIMAL_CONFIG = {
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 1,
        "gradient_accumulation_steps": 16,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    }
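
    # How OPTIMAL_CONFIG would map onto a PEFT LoraConfig (a sketch; peft is
    # not imported by this module, and the lora_dropout value is an assumption):
    #
    #   from peft import LoraConfig
    #   lora_config = LoraConfig(
    #       r=16,
    #       lora_alpha=32,
    #       target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    #       lora_dropout=0.05,
    #   )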

    def __init__(self, adapter_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2"):
        super().__init__("TrainingSpecialist")
        self.adapter_repo = adapter_repo

    def validate_adapter_config(self) -> Dict:
        """Validate the adapter configuration."""
        try:
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=self.adapter_repo,
                filename="stage4/adapter_config.json",
            )

            with open(config_path) as f:
                config = json.load(f)

            issues = []
            recommendations = []

            # Sanity-check the LoRA rank.
            if config.get("r", 0) < 8:
                issues.append("LoRA rank too low (r < 8)")
            elif config.get("r", 0) > 64:
                recommendations.append("Consider reducing LoRA rank for efficiency")

            # The adapter should target at least one projection layer.
            target_modules = config.get("target_modules", [])
            if not any("proj" in m for m in target_modules):
                issues.append("No projection layers targeted")

            return {
                "config": config,
                "issues": issues,
                "recommendations": recommendations,
                "valid": len(issues) == 0,
            }

        except Exception as e:
            return {"error": str(e), "valid": False}

    def recommend_next_training(self, current_metrics: Optional[Dict] = None) -> Dict:
        """Recommend next training steps based on current metrics."""
        recommendations = []

        if not current_metrics:
            recommendations.append({
                "priority": "HIGH",
                "action": "Run evaluation to get baseline metrics",
                "method": "scripts/evaluate_unified_vlm.py",
            })
        else:
            accuracy = current_metrics.get("accuracy", 0)

            if accuracy < 0.7:
                recommendations.append({
                    "priority": "HIGH",
                    "action": "Increase training epochs or data",
                    "method": "SFT with more epochs",
                })

            if 0.7 <= accuracy < 0.85:
                recommendations.append({
                    "priority": "MEDIUM",
                    "action": "Consider DPO for preference learning",
                    "method": "Create chosen/rejected pairs from predictions",
                })

            if accuracy >= 0.85:
                recommendations.append({
                    "priority": "LOW",
                    "action": "Model performing well - focus on inference optimization",
                    "method": "Merge adapters, quantize for deployment",
                })

        return {"recommendations": recommendations}

    def run(self) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log(f"Validating training setup: {self.adapter_repo}")

        validation = self.validate_adapter_config()

        if validation.get("valid"):
            self.status = AgentStatus.SUCCESS
            recommendations = self.recommend_next_training()

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Training config valid. LoRA r={validation['config'].get('r')}",
                {
                    "config": validation["config"],
                    "recommendations": recommendations["recommendations"],
                },
            )
        elif validation.get("error"):
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Could not load adapter: {validation['error']}",
            )
        else:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Issues found: {validation['issues']}",
                validation,
            )


class EvaluationSpecialistAgent(BaseAgent):
    """
    Specialist in model evaluation (metrics, benchmarks, validation).

    Responsibilities:
    - Compute accuracy, F1, precision, recall
    - Validate coordinate predictions (MAE, quadrant accuracy)
    - Compare against baselines
    - Generate evaluation reports
    """

    METRICS = {
        "classification": ["accuracy", "f1", "precision", "recall"],
        "localization": ["mae", "quadrant_accuracy", "distance_error"],
        "detection": ["iou", "ap", "ar"],
    }

    # "mae" is a distance in the 0-100 coordinate space; the other two are
    # rates in [0, 1].
    THRESHOLDS = {
        "quadrant_accuracy": 0.75,
        "mae": 15.0,
        "classification_accuracy": 0.80,
    }

    def __init__(self, model_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2"):
        super().__init__("EvaluationSpecialist")
        self.model_repo = model_repo
        self.metrics = {}

    def load_evaluation_results(self) -> Dict:
        """Load existing evaluation results if available."""
        try:
            with open("evaluation_results.json") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def compute_quick_metrics(self, predictions: List[Dict]) -> Dict:
        """Compute quick metrics from predictions."""
        if not predictions:
            return {}

        metrics = {}

        # Coordinate (pointing) predictions.
        coord_preds = [p for p in predictions if p.get("task") in ["point", "pointing"]]
        if coord_preds:
            valid = [p for p in coord_preds if p.get("x") is not None]
            metrics["valid_rate"] = len(valid) / len(coord_preds)

            # Euclidean distance to ground truth where it is available.
            errors = []
            for p in valid:
                if p.get("gt_x") is not None and p.get("gt_y") is not None:
                    err = ((p["x"] - p["gt_x"]) ** 2 + (p["y"] - p["gt_y"]) ** 2) ** 0.5
                    errors.append(err)

            if errors:
                metrics["mae"] = sum(errors) / len(errors)
                metrics["quadrant_accuracy"] = sum(1 for e in errors if e < 25) / len(errors)

        # Classification predictions.
        class_preds = [p for p in predictions if p.get("task") == "classification"]
        if class_preds:
            correct = sum(1 for p in class_preds if p.get("prediction") == p.get("ground_truth"))
            metrics["classification_accuracy"] = correct / len(class_preds)

        return metrics
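
    # Worked example (illustrative): a prediction at (40, 50) against ground
    # truth (43, 54) has error sqrt(3**2 + 4**2) = 5.0; it contributes 5.0 to
    # the MAE average and counts as a hit under the 25-unit tolerance.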

    def evaluate_against_thresholds(self, metrics: Dict) -> Dict:
        """Check metrics against quality thresholds."""
        results = {"passed": [], "failed": [], "warnings": []}

        for metric, threshold in self.THRESHOLDS.items():
            if metric in metrics:
                value = metrics[metric]
                # MAE is better when lower; every other metric when higher.
                if metric == "mae":
                    passed = value <= threshold
                else:
                    passed = value >= threshold

                entry = {"metric": metric, "value": value, "threshold": threshold}
                if passed:
                    results["passed"].append(entry)
                else:
                    results["failed"].append(entry)

        return results

    def generate_report(self, metrics: Dict, threshold_results: Dict) -> str:
        """Generate an evaluation report."""
        report = []
        report.append("=" * 50)
        report.append("EVALUATION REPORT")
        report.append("=" * 50)

        report.append("\n📊 METRICS:")
        for k, v in metrics.items():
            report.append(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")

        report.append("\n✅ PASSED:")
        for item in threshold_results["passed"]:
            report.append(f"  {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        if threshold_results["failed"]:
            report.append("\n❌ FAILED:")
            for item in threshold_results["failed"]:
                report.append(f"  {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        return "\n".join(report)

    def run(self, predictions: Optional[List[Dict]] = None) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log("Running evaluation...")

        # Prefer previously saved results; otherwise compute from predictions.
        existing = self.load_evaluation_results()

        if existing:
            self.log("Found existing evaluation results")
            self.metrics = existing
        elif predictions:
            self.log(f"Computing metrics from {len(predictions)} predictions")
            self.metrics = self.compute_quick_metrics(predictions)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No predictions available for evaluation",
            )

        threshold_results = self.evaluate_against_thresholds(self.metrics)

        report = self.generate_report(self.metrics, threshold_results)
        self.log(f"\n{report}")

        if threshold_results["failed"]:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"{len(threshold_results['failed'])} metrics below threshold",
                {"metrics": self.metrics, "thresholds": threshold_results},
            )
        else:
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"All {len(threshold_results['passed'])} metrics passed",
                {"metrics": self.metrics, "thresholds": threshold_results},
            )


class DemoSyncAgent(BaseAgent):
    """Syncs curated examples to the Gradio Space."""

    def __init__(self, space_id: str = "mmrech/pitvqa-surgical-vlm"):
        super().__init__("DemoSync")
        self.space_id = space_id

    def run(self, curated_examples: Optional[List[Dict]] = None) -> AgentResult:
        self.status = AgentStatus.RUNNING
        self.log(f"Syncing to Space: {self.space_id}")

        if not curated_examples:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No curated examples to sync",
            )

        try:
            from huggingface_hub import HfApi
            api = HfApi()

            try:
                info = api.space_info(self.space_id)
                runtime = info.runtime

                if runtime and runtime.stage == "RUNNING":
                    self.log("Space is running", "SUCCESS")

                    # Serialize the payload; this agent only verifies
                    # readiness and does not upload to the Space itself.
                    examples_json = json.dumps(curated_examples, indent=2)

                    self.status = AgentStatus.SUCCESS
                    return AgentResult(
                        self.name,
                        AgentStatus.SUCCESS,
                        f"Space running, {len(curated_examples)} examples ready for sync",
                        {"space_status": "RUNNING", "examples_count": len(curated_examples)},
                    )
                else:
                    self.status = AgentStatus.WAITING
                    return AgentResult(
                        self.name,
                        AgentStatus.WAITING,
                        f"Space not running: {runtime.stage if runtime else 'unknown'}",
                    )

            except Exception as e:
                self.status = AgentStatus.FAILED
                return AgentResult(self.name, AgentStatus.FAILED, f"Space error: {e}")

        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")


class PitVQAOrchestrator:
    """Coordinates all agents for the PitVQA pipeline."""

    def __init__(self, job_ids: List[str]):
        self.agents = {
            "monitor": JobMonitorAgent(job_ids),
            "curation": CurationAgent(),
            "dataset": DatasetValidatorAgent(),
            "model": ModelVerifierAgent(),
            "training": TrainingSpecialistAgent(),
            "evaluation": EvaluationSpecialistAgent(),
            "demo": DemoSyncAgent(),
        }
        self.results = {}
        self.run_count = 0

    def run_cycle(self) -> Dict:
        """Run one orchestration cycle."""
        self.run_count += 1
        print(f"\n{'=' * 60}")
        print(f"🔄 ORCHESTRATION CYCLE {self.run_count}")
        print(f"{'=' * 60}")

        print("\n🔍 Phase 1: Job Monitoring")
        monitor_result = self.agents["monitor"].run()
        self.results["monitor"] = monitor_result

        print("\n🎓 Phase 2: Training Validation (HF-LLM-Trainer)")
        training_result = self.agents["training"].run()
        self.results["training"] = training_result

        # Downstream phases only run if the jobs have not failed outright.
        if monitor_result.status in [AgentStatus.SUCCESS, AgentStatus.WAITING]:

            print("\n🎨 Phase 3: Curation")
            curation_result = self.agents["curation"].run()
            self.results["curation"] = curation_result

            print("\n📦 Phase 4: Dataset Validation")
            dataset_result = self.agents["dataset"].run()
            self.results["dataset"] = dataset_result

            print("\n🤖 Phase 5: Model Verification")
            model_result = self.agents["model"].run()
            self.results["model"] = model_result

            print("\n📊 Phase 6: Evaluation (Metrics & Quality)")
            curated = curation_result.data.get("examples", []) if curation_result.data else []
            eval_result = self.agents["evaluation"].run(predictions=curated)
            self.results["evaluation"] = eval_result

            print("\n🚀 Phase 7: Demo Sync")
            demo_result = self.agents["demo"].run(curated)
            self.results["demo"] = demo_result

        return self.generate_report()

    def generate_report(self) -> Dict:
        """Generate a comprehensive status report."""
        report = {
            "timestamp": datetime.now().isoformat(),
            "cycle": self.run_count,
            "overall_status": self._compute_overall_status(),
            "agents": {},
        }

        for name, result in self.results.items():
            report["agents"][name] = {
                "status": result.status.value,
                "message": result.message,
            }

        return report
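
    # Report shape (illustrative values):
    #   {"timestamp": "2025-01-01T00:00:00", "cycle": 1,
    #    "overall_status": "IN_PROGRESS",
    #    "agents": {"monitor": {"status": "waiting", "message": "..."}}}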

    def _compute_overall_status(self) -> str:
        """Compute the overall pipeline status."""
        statuses = [r.status for r in self.results.values()]

        # Guard the empty case: all() over an empty list is True, which would
        # otherwise report COMPLETE before any agent has run.
        if not statuses:
            return "UNKNOWN"
        if all(s == AgentStatus.SUCCESS for s in statuses):
            return "COMPLETE"
        elif any(s == AgentStatus.FAILED for s in statuses):
            return "NEEDS_ATTENTION"
        elif any(s == AgentStatus.WAITING for s in statuses):
            return "IN_PROGRESS"
        else:
            return "UNKNOWN"

    def print_summary(self, report: Dict):
        """Print a human-readable summary."""
        print(f"\n{'=' * 60}")
        print("📋 ORCHESTRATION SUMMARY")
        print(f"{'=' * 60}")
        print(f"Time: {report['timestamp']}")
        print(f"Cycle: {report['cycle']}")
        print(f"Overall: {report['overall_status']}")
        print("\nAgent Status:")
        for name, info in report["agents"].items():
            icon = {"success": "✅", "failed": "❌", "waiting": "⏳", "running": "🔄"}.get(info["status"], "❓")
            print(f"  {icon} {name}: {info['status']} - {info['message'][:50]}")


def main():
    print("🚀 PitVQA Multi-Agent Orchestrator Starting...")

    # HuggingFace Job IDs tracked by this session.
    job_ids = [
        "696cfe9946affbb321046bd9",
        "696cfebf57a10a9d296ca042",
    ]

    orchestrator = PitVQAOrchestrator(job_ids)

    report = orchestrator.run_cycle()
    orchestrator.print_summary(report)

    # Persist the report for later inspection.
    with open("orchestration_report.json", "w") as f:
        json.dump(report, f, indent=2)
    print("\n💾 Report saved to orchestration_report.json")

    return report
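
# Polling sketch (illustrative; the five-minute cadence is an assumption):
#
#   import time
#   orchestrator = PitVQAOrchestrator(job_ids=["<job-id>"])
#   while orchestrator.run_cycle()["overall_status"] == "IN_PROGRESS":
#       time.sleep(300)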


if __name__ == "__main__":
    main()