mmrech committed on
Commit
e4a4f4c
·
verified ·
1 Parent(s): 69375af

Upload pitvqa_agent_orchestrator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pitvqa_agent_orchestrator.py +913 -0
pitvqa_agent_orchestrator.py ADDED
@@ -0,0 +1,913 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "huggingface_hub>=0.21.0",
6
+ # "requests",
7
+ # ]
8
+ # ///
9
+ """
10
PitVQA Multi-Agent Orchestration System

Specialized agents for methodologically rigorous VLM pipeline management:
1. JobMonitorAgent - Track HuggingFace Jobs status
2. CurationAgent - Quality-filter showcase examples
3. DatasetValidatorAgent - Validate image-embedded dataset
4. ModelVerifierAgent - Check that the merged model repo contains its files
5. TrainingSpecialistAgent - Validate the LoRA adapter configuration
6. EvaluationSpecialistAgent - Compute metrics and check them against thresholds
7. DemoSyncAgent - Check the Gradio Space and stage curated examples

Run with: python pitvqa_agent_orchestrator.py
20
+ """
21
+
22
+ import os
23
+ import json
24
+ import time
25
+ from dataclasses import dataclass
26
+ from typing import Dict, List, Optional, Any
27
+ from datetime import datetime
28
+ from enum import Enum
29
+
30
+ # ============================================================
31
+ # Agent Status Types
32
+ # ============================================================
33
+
34
class AgentStatus(Enum):
    """Lifecycle state reported by an agent.

    The string values are what ends up in generated reports
    (``status.value``).
    """

    # Initial state assigned by BaseAgent.__init__.
    IDLE = "idle"
    # Set by each agent at the top of run().
    RUNNING = "running"
    # Terminal outcomes of a run.
    SUCCESS = "success"
    FAILED = "failed"
    # The agent's inputs are not available yet; try again in a later cycle.
    WAITING = "waiting"
40
+
41
@dataclass
class AgentResult:
    """Outcome of one agent run.

    Fields:
        agent_name: Name of the agent that produced the result.
        status: Resulting :class:`AgentStatus`.
        message: Human-readable summary.
        data: Optional payload attached to the result.
        timestamp: ISO-8601 string; filled in automatically when left empty.
    """

    agent_name: str
    status: AgentStatus
    message: str
    data: Optional[Dict] = None
    timestamp: str = ""

    def __post_init__(self):
        # An empty timestamp means "stamp with the current local time".
        self.timestamp = self.timestamp or datetime.now().isoformat()
52
+
53
+ # ============================================================
54
+ # Base Agent
55
+ # ============================================================
56
+
57
class BaseAgent:
    """Base class for all PitVQA agents.

    Holds the agent's display name, its current :class:`AgentStatus` and a
    list of recorded :class:`AgentResult` objects. Subclasses implement
    :meth:`run`.
    """

    # Console icon per log level. The literals were mis-encoded (mojibake)
    # in the original text and are restored to the intended emoji here.
    _LOG_ICONS = {"INFO": "ℹ️", "SUCCESS": "✅", "ERROR": "❌", "WARN": "⚠️"}
    # Used for any level that is not listed above.
    _DEFAULT_ICON = "📌"

    def __init__(self, name: str):
        self.name = name
        self.status = AgentStatus.IDLE
        self.results: List[AgentResult] = []

    def log(self, message: str, level: str = "INFO"):
        """Print ``message`` prefixed with the agent name and a level icon."""
        icon = self._LOG_ICONS.get(level, self._DEFAULT_ICON)
        print(f"[{self.name}] {icon} {message}")

    def run(self) -> AgentResult:
        """Execute the agent; must be overridden by subclasses."""
        raise NotImplementedError

    def report(self) -> Dict:
        """Return a summary dict with the agent name, status and results.

        Each result is returned as a shallow copy of its attribute dict so
        that callers editing the report cannot alter the stored results.
        """
        return {
            "agent": self.name,
            "status": self.status.value,
            "results": [dict(vars(r)) for r in self.results],
        }
78
+
79
+ # ============================================================
80
+ # Agent 1: Job Monitor
81
+ # ============================================================
82
+
83
class JobMonitorAgent(BaseAgent):
    """Monitors HuggingFace Jobs and reports status."""

    # Stage names that mean a job finished successfully.
    DONE_STAGES = ("COMPLETED", "SUCCESS")
    # Stage names that mean a job will never finish successfully. Cancelled
    # and deleted jobs are included so they are not reported as
    # "still running" indefinitely.
    FAILED_STAGES = ("FAILED", "ERROR", "CANCELED", "DELETED")

    def __init__(self, job_ids: List[str]):
        super().__init__("JobMonitor")
        self.job_ids = job_ids
        # Last known status dict per job id, refreshed on every run().
        self.job_status = {}

    def check_job(self, job_id: str) -> Dict:
        """Check single job status using HF API.

        Returns a dict with ``id``, ``status`` and ``message``. If the
        lookup fails for any reason the dict instead carries
        ``status == "UNKNOWN"`` and an ``error`` string; this method does
        not raise.
        """
        try:
            from huggingface_hub import HfApi

            api = HfApi()

            # huggingface_hub exposes job lookup as ``inspect_job``; the
            # previously used ``get_job`` is kept only as a fallback for
            # environments that happen to provide it.
            inspect = getattr(api, "inspect_job", None)
            if inspect is not None:
                job = inspect(job_id=job_id)
            else:
                job = api.get_job(job_id)

            status = job.status
            stage = getattr(status, "stage", None)
            if stage is None:
                stage_name = str(status)
            else:
                # Enum stages are reduced to their plain value so logs and
                # comparisons see e.g. "COMPLETED".
                stage_name = str(getattr(stage, "value", stage))

            return {
                "id": job_id,
                "status": stage_name,
                "message": getattr(status, "message", None),
            }
        except Exception as e:  # best-effort: any failure means "unknown"
            return {"id": job_id, "status": "UNKNOWN", "error": str(e)}

    def run(self) -> AgentResult:
        """Check every configured job and summarise the outcome.

        Returns FAILED if any job is in a failed stage, SUCCESS if all jobs
        are done, and WAITING otherwise (including unknown stages).
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Checking {len(self.job_ids)} jobs...")

        all_complete = True
        any_failed = False

        for job_id in self.job_ids:
            status = self.check_job(job_id)
            self.job_status[job_id] = status

            stage = status.get("status", "UNKNOWN")
            self.log(f"Job {job_id[:8]}: {stage}")

            normalized = str(stage).upper()
            if normalized not in self.DONE_STAGES:
                all_complete = False
            if normalized in self.FAILED_STAGES:
                any_failed = True

        if any_failed:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, "Some jobs failed", self.job_status)
        if all_complete:
            self.status = AgentStatus.SUCCESS
            return AgentResult(self.name, AgentStatus.SUCCESS, "All jobs complete", self.job_status)

        self.status = AgentStatus.WAITING
        return AgentResult(self.name, AgentStatus.WAITING, "Jobs still running", self.job_status)
135
+
136
+ # ============================================================
137
+ # Agent 2: Curation Agent
138
+ # ============================================================
139
+
140
class CurationAgent(BaseAgent):
    """Curates showcase examples based on quality criteria."""

    # Named predicates kept for callers; not used by the methods below.
    QUALITY_CRITERIA = {
        "coordinate_validity": lambda x, y: 0 <= x <= 100 and 0 <= y <= 100,
        "coordinate_diversity": lambda coords: len(set(coords)) > len(coords) * 0.5,
        "video_diversity": lambda vids: len(set(vids)) >= min(5, len(vids)),
        "frame_diversity": lambda frames: len(set(frames)) >= min(8, len(frames)),
    }

    # Minimum number of curated examples for run() to report SUCCESS.
    MIN_CURATED = 8
    # Maximum number of curated examples taken from a single video.
    MAX_PER_VIDEO = 2

    def __init__(self, results_path: str = "./curation_review/all_results.json"):
        super().__init__("Curation")
        self.results_path = results_path
        self.curated_examples = []

    def load_results(self) -> List[Dict]:
        """Load raw curation results.

        Returns an empty list (after logging a warning) when the file is
        missing, is not valid JSON (e.g. still being written), or does not
        contain a JSON list.
        """
        try:
            with open(self.results_path, encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            self.log("Results file not found - job may still be running", "WARN")
            return []
        except json.JSONDecodeError as e:
            self.log(f"Results file is not valid JSON yet: {e}", "WARN")
            return []

        if not isinstance(loaded, list):
            self.log("Results file does not contain a list of examples", "WARN")
            return []
        return loaded

    def score_example(self, example: Dict) -> float:
        """Score a single example (0-1).

        Components: 0.3 for ``success``; up to 0.3 for coordinate/box
        plausibility; 0.2 if the response contains a ``<point``/``<box``
        tag; 0.2 if the (non-empty) target name appears in the response.
        """
        score = 0.0

        # Basic validity
        if example.get("success"):
            score += 0.3

        # Coordinate quality
        task = example.get("task")
        if task == "point":
            x, y = example.get("x"), example.get("y")
            # Compare against None so a genuine 0 coordinate is still scored
            # (as an edge coordinate) instead of being treated as missing.
            if x is not None and y is not None:
                # Penalize edge coordinates (likely failures)
                score += 0.3 if (10 < x < 90 and 10 < y < 90) else 0.1
        elif task == "bbox":
            bbox = example.get("bbox")
            if bbox and len(bbox) == 4:
                # Penalize tiny or huge boxes
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                score += 0.3 if 100 < area < 5000 else 0.1

        # Response coherence; a stored None response counts as empty text.
        response = example.get("response") or ""
        if "<point" in response or "<box" in response:
            score += 0.2

        # Target relevance. An empty target must not match: the empty string
        # is a substring of every response.
        target = (example.get("target") or "").strip().lower()
        if target and target in response.lower():
            score += 0.2

        return min(score, 1.0)

    def curate(self, results: List[Dict], top_k: int = 12) -> List[Dict]:
        """Select best diverse examples.

        Successful examples are ranked by :meth:`score_example` and taken
        greedily subject to: at most ``MAX_PER_VIDEO`` per video, unique
        (video, frame) pairs, and at most half of ``top_k`` (rounded up)
        per task. Each returned dict carries an added ``quality_score``.
        """
        if not results:
            return []

        # Score all examples
        scored = [(self.score_example(ex), ex) for ex in results if ex.get("success")]
        scored.sort(key=lambda pair: pair[0], reverse=True)

        curated: List[Dict] = []
        # Per-video tally. A set cannot count occurrences (and has no
        # ``.count``), so a dict is used for the per-video limit.
        per_video: Dict[Any, int] = {}
        used_frames = set()
        used_tasks = {"point": 0, "bbox": 0}
        # Rounded up so an odd top_k can still be filled by two tasks.
        task_cap = (top_k + 1) // 2

        for score, ex in scored:
            if len(curated) >= top_k:
                break

            video = ex.get("video_id")
            frame = ex.get("frame_idx")
            task = ex.get("task")

            # Diversity constraints
            if per_video.get(video, 0) >= self.MAX_PER_VIDEO:
                continue
            if (video, frame) in used_frames:  # Unique video+frame combos
                continue
            if used_tasks.get(task, 0) >= task_cap:  # Balance tasks
                continue

            curated.append({**ex, "quality_score": score})
            per_video[video] = per_video.get(video, 0) + 1
            used_frames.add((video, frame))
            used_tasks[task] = used_tasks.get(task, 0) + 1

        return curated

    def run(self) -> AgentResult:
        """Load, score and curate examples.

        Returns WAITING when no results are available, SUCCESS when at
        least ``MIN_CURATED`` examples were selected, FAILED otherwise.
        """
        self.status = AgentStatus.RUNNING
        self.log("Loading curation results...")

        results = self.load_results()
        if not results:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "No results available yet")

        self.log(f"Scoring {len(results)} examples...")
        self.curated_examples = self.curate(results)
        count = len(self.curated_examples)

        if count < self.MIN_CURATED:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Only {count} examples passed quality checks",
            )

        self.status = AgentStatus.SUCCESS

        # Report diversity. ``.get`` mirrors curate(), which tolerates
        # examples without these keys.
        videos = {ex.get("video_id") for ex in self.curated_examples}
        frames = {ex.get("frame_idx") for ex in self.curated_examples}

        self.log(f"Curated {count} examples", "SUCCESS")
        self.log(f" Videos: {len(videos)} unique")
        self.log(f" Frames: {len(frames)} unique")

        return AgentResult(
            self.name,
            AgentStatus.SUCCESS,
            f"Curated {count} high-quality diverse examples",
            {"examples": self.curated_examples},
        )
277
+
278
+ # ============================================================
279
+ # Agent 3: Dataset Validator
280
+ # ============================================================
281
+
282
class DatasetValidatorAgent(BaseAgent):
    """Validates image-embedded dataset quality.

    Loads the first ``sample_size`` rows of the ``train`` split and checks
    that the required columns exist and every sampled image has a non-zero
    size.
    """

    REQUIRED_FIELDS = ("image", "messages")

    def __init__(
        self,
        dataset_id: str = "mmrech/pitvqa-spatial-with-images",
        sample_size: int = 10,
    ):
        super().__init__("DatasetValidator")
        self.dataset_id = dataset_id
        self.sample_size = sample_size

    def run(self) -> AgentResult:
        """Validate a small sample of the dataset.

        Returns FAILED when the ``datasets`` package is missing, required
        fields are absent, the sample is empty or any image is invalid;
        WAITING when loading the dataset raises; SUCCESS otherwise.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Validating dataset: {self.dataset_id}")

        try:
            from datasets import load_dataset
        except ImportError as e:
            # A missing package is an environment problem, not a dataset
            # that is "not yet available" - report it as such.
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"The 'datasets' package is not installed: {e}",
            )

        try:
            # Try to load dataset
            ds = load_dataset(self.dataset_id, split=f"train[:{self.sample_size}]")

            # Check required fields
            missing = [f for f in self.REQUIRED_FIELDS if f not in ds.features]
            if missing:
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    f"Missing fields: {missing}",
                )

            total = len(ds)
            if total == 0:
                # Without this guard "0/0 images OK" would count as success.
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    "Dataset sample is empty",
                )

            # Validate image quality
            valid_images = 0
            for ex in ds:
                img = ex.get("image")
                size = getattr(img, "size", None)
                if img is not None and size and size[0] > 0:
                    valid_images += 1

            if valid_images == total:
                self.status = AgentStatus.SUCCESS
                return AgentResult(
                    self.name,
                    AgentStatus.SUCCESS,
                    f"Dataset valid: {valid_images}/{total} images OK",
                    {"sample_count": total, "valid_images": valid_images},
                )

            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Invalid images: {total - valid_images}/{total}",
            )

        except Exception as e:  # best-effort: treat load problems as "not ready"
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Dataset not yet available: {e}",
            )
341
+
342
+ # ============================================================
343
+ # Agent 4: Model Verifier
344
+ # ============================================================
345
+
346
class ModelVerifierAgent(BaseAgent):
    """Verifies that the merged model repository contains its files.

    Only the repository file listing is inspected; the model is not loaded
    and ``TEST_PROMPTS`` are not executed by this class.
    """

    TEST_PROMPTS = [
        ("Point to the suction device", "point"),
        ("Draw a bounding box around the surgical instrument", "bbox"),
        ("What surgical phase is this?", "classification"),
    ]

    def __init__(self, model_id: str = "mmrech/pitvqa-qwen2vl-merged"):
        super().__init__("ModelVerifier")
        self.model_id = model_id

    def run(self) -> AgentResult:
        """Check the model repo for a config and weight files.

        Returns FAILED if the Hub client cannot be created or files are
        missing, WAITING if the repo lookup raises, SUCCESS otherwise.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Verifying model: {self.model_id}")

        try:
            from huggingface_hub import HfApi

            api = HfApi()
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")

        # Check if model exists
        try:
            info = api.model_info(self.model_id)
            self.log(f"Model found: {info.modelId}")
            # ``siblings`` can be None; treat that as an empty listing so it
            # is reported as missing files rather than as a lookup error.
            files = [f.rfilename for f in (info.siblings or [])]
        except Exception as e:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Model not yet available: {e}",
            )

        # Check if main model files exist
        has_model = any("safetensors" in f or "pytorch" in f for f in files)
        has_config = "config.json" in files

        if has_model and has_config:
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Model verified: {len(files)} files present",
                {"files": files[:10]},  # First 10 files
            )

        self.status = AgentStatus.FAILED
        return AgentResult(
            self.name,
            AgentStatus.FAILED,
            f"Missing model files (has_model={has_model}, has_config={has_config})",
        )
407
+
408
+ # ============================================================
409
+ # Agent 5: Training Specialist (HF-LLM-Trainer)
410
+ # ============================================================
411
+
412
class TrainingSpecialistAgent(BaseAgent):
    """
    Specialist in HuggingFace LLM Training (TRL/SFT/LoRA/DPO).

    Responsibilities:
    - Validate training configurations
    - Check adapter quality
    - Recommend training improvements
    - Verify LoRA/PEFT setup
    """

    TRAINING_METHODS = {
        "SFT": "Supervised Fine-Tuning - learning from (input, output) pairs",
        "LoRA": "Low-Rank Adaptation - parameter-efficient adapters",
        "DPO": "Direct Preference Optimization - learning from preferences",
        "RLHF": "Reinforcement Learning from Human Feedback",
    }

    OPTIMAL_CONFIG = {
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 1,
        "gradient_accumulation_steps": 16,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    }

    def __init__(
        self,
        adapter_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2",
        config_filename: str = "stage4/adapter_config.json",
    ):
        super().__init__("TrainingSpecialist")
        self.adapter_repo = adapter_repo
        # Path of the adapter config inside the repo.
        self.config_filename = config_filename

    def validate_adapter_config(self) -> Dict:
        """Validate adapter configuration.

        Returns ``{"config", "issues", "recommendations", "valid"}`` on
        success, or ``{"error", "valid": False}`` if the config could not
        be downloaded, parsed or inspected. Does not raise.
        """
        try:
            from huggingface_hub import hf_hub_download

            # Download adapter config
            config_path = hf_hub_download(
                repo_id=self.adapter_repo,
                filename=self.config_filename,
            )

            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)

            # Check key parameters
            issues = []
            recommendations = []

            # Check LoRA rank (a missing or null rank counts as 0)
            rank = config.get("r") or 0
            if rank < 8:
                issues.append("LoRA rank too low (r < 8)")
            elif rank > 64:
                recommendations.append("Consider reducing LoRA rank for efficiency")

            # Check target modules. PEFT also accepts a single string here
            # (a pattern, or "all-linear"); wrap it so the check does not
            # iterate over its characters.
            target_modules = config.get("target_modules") or []
            if isinstance(target_modules, str):
                target_modules = [target_modules]
            if not any("proj" in m or m == "all-linear" for m in target_modules):
                issues.append("No projection layers targeted")

            return {
                "config": config,
                "issues": issues,
                "recommendations": recommendations,
                "valid": len(issues) == 0,
            }

        except Exception as e:
            # Never return an empty error string: run() uses the presence of
            # an error to tell "could not load" from "loaded but invalid".
            return {"error": str(e) or type(e).__name__, "valid": False}

    def recommend_next_training(self, current_metrics: Optional[Dict] = None) -> Dict:
        """Recommend next training steps based on current metrics.

        With no metrics, recommends running an evaluation first; otherwise
        picks one recommendation from the ``accuracy`` value (missing
        accuracy is treated as 0).
        """
        if not current_metrics:
            return {
                "recommendations": [{
                    "priority": "HIGH",
                    "action": "Run evaluation to get baseline metrics",
                    "method": "scripts/evaluate_unified_vlm.py",
                }]
            }

        accuracy = current_metrics.get("accuracy", 0)
        if accuracy < 0.7:
            recommendation = {
                "priority": "HIGH",
                "action": "Increase training epochs or data",
                "method": "SFT with more epochs",
            }
        elif accuracy < 0.85:
            recommendation = {
                "priority": "MEDIUM",
                "action": "Consider DPO for preference learning",
                "method": "Create chosen/rejected pairs from predictions",
            }
        else:
            recommendation = {
                "priority": "LOW",
                "action": "Model performing well - focus on inference optimization",
                "method": "Merge adapters, quantize for deployment",
            }

        return {"recommendations": [recommendation]}

    def run(self) -> AgentResult:
        """Validate the adapter config and attach training recommendations.

        Returns SUCCESS for a valid config, WAITING when the config could
        not be loaded, FAILED when it loaded but has issues.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Validating training setup: {self.adapter_repo}")

        # Validate adapter
        validation = self.validate_adapter_config()

        if validation.get("valid"):
            self.status = AgentStatus.SUCCESS
            recommendations = self.recommend_next_training()

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Training config valid. LoRA r={validation['config'].get('r')}",
                {
                    "config": validation["config"],
                    "recommendations": recommendations["recommendations"],
                },
            )

        # Key presence, not truthiness, decides the branch.
        if "error" in validation:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Could not load adapter: {validation['error']}",
            )

        self.status = AgentStatus.FAILED
        return AgentResult(
            self.name,
            AgentStatus.FAILED,
            f"Issues found: {validation['issues']}",
            validation,
        )
554
+
555
+ # ============================================================
556
+ # Agent 6: Evaluation Specialist
557
+ # ============================================================
558
+
559
class EvaluationSpecialistAgent(BaseAgent):
    """
    Specialist in Model Evaluation (metrics, benchmarks, validation).

    Responsibilities:
    - Compute accuracy, F1, precision, recall
    - Validate coordinate predictions (MAE, quadrant accuracy)
    - Compare against baselines
    - Generate evaluation reports
    """

    METRICS = {
        "classification": ["accuracy", "f1", "precision", "recall"],
        "localization": ["mae", "quadrant_accuracy", "distance_error"],
        "detection": ["iou", "ap", "ar"],
    }

    THRESHOLDS = {
        "quadrant_accuracy": 0.75,  # Minimum acceptable
        "mae": 15.0,  # Maximum acceptable (percentage)
        "classification_accuracy": 0.80,
    }

    # Metrics where a smaller value is better (threshold is a maximum).
    LOWER_IS_BETTER = frozenset({"mae"})
    # A point prediction counts towards "quadrant_accuracy" when its distance
    # to the ground truth is below this value.
    HIT_RADIUS = 25

    def __init__(
        self,
        model_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2",
        results_path: str = "evaluation_results.json",
    ):
        super().__init__("EvaluationSpecialist")
        self.model_repo = model_repo
        self.results_path = results_path
        self.metrics = {}

    def load_evaluation_results(self) -> Dict:
        """Load existing evaluation results if available.

        Returns ``{}`` when the file is missing, is not valid JSON, or does
        not contain a JSON object.
        """
        try:
            with open(self.results_path, encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            return {}
        except json.JSONDecodeError as e:
            self.log(f"Ignoring unreadable evaluation results: {e}", "WARN")
            return {}
        return loaded if isinstance(loaded, dict) else {}

    def compute_quick_metrics(self, predictions: List[Dict]) -> Dict:
        """Compute quick metrics from predictions.

        Produces, when the inputs allow it:
        - ``valid_rate``: share of point predictions with both x and y set;
        - ``mae``: mean Euclidean distance between prediction and ground
          truth (the key name is kept for compatibility);
        - ``quadrant_accuracy``: share of those distances below
          ``HIT_RADIUS``;
        - ``classification_accuracy``: exact-match rate.
        """
        if not predictions:
            return {}

        metrics = {}

        # Coordinate predictions
        coord_preds = [p for p in predictions if p.get("task") in ("point", "pointing")]
        if coord_preds:
            # Both coordinates must be present; a bare x would raise
            # KeyError in the distance computation below.
            valid = [
                p for p in coord_preds
                if p.get("x") is not None and p.get("y") is not None
            ]
            metrics["valid_rate"] = len(valid) / len(coord_preds)

            # Calculate error if ground truth available. Compare against
            # None so a ground-truth coordinate of 0 is not skipped.
            errors = [
                ((p["x"] - p["gt_x"]) ** 2 + (p["y"] - p["gt_y"]) ** 2) ** 0.5
                for p in valid
                if p.get("gt_x") is not None and p.get("gt_y") is not None
            ]

            if errors:
                metrics["mae"] = sum(errors) / len(errors)
                hits = sum(1 for e in errors if e < self.HIT_RADIUS)
                metrics["quadrant_accuracy"] = hits / len(errors)

        # Classification predictions
        class_preds = [p for p in predictions if p.get("task") == "classification"]
        if class_preds:
            correct = sum(1 for p in class_preds if p.get("prediction") == p.get("ground_truth"))
            metrics["classification_accuracy"] = correct / len(class_preds)

        return metrics

    def evaluate_against_thresholds(self, metrics: Dict) -> Dict:
        """Check metrics against quality thresholds.

        Returns ``{"passed": [...], "failed": [...], "warnings": [...]}``.
        Metrics without a threshold are ignored; thresholded metrics whose
        value is not a number go to ``warnings`` instead of being compared.
        """
        results = {"passed": [], "failed": [], "warnings": []}

        for metric, threshold in self.THRESHOLDS.items():
            if metric not in metrics:
                continue

            value = metrics[metric]
            entry = {"metric": metric, "value": value, "threshold": threshold}

            # Values loaded from a results file may be strings or nested
            # objects; comparing or formatting those would raise.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                results["warnings"].append(entry)
                continue

            if metric in self.LOWER_IS_BETTER:
                passed = value <= threshold
            else:
                passed = value >= threshold

            results["passed" if passed else "failed"].append(entry)

        return results

    def generate_report(self, metrics: Dict, threshold_results: Dict) -> str:
        """Generate evaluation report as a multi-line string."""
        report = ["=" * 50, "EVALUATION REPORT", "=" * 50]

        report.append("\n📊 METRICS:")
        for k, v in metrics.items():
            report.append(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")

        report.append("\n✅ PASSED:")
        for item in threshold_results["passed"]:
            report.append(f" {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        if threshold_results["failed"]:
            report.append("\n❌ FAILED:")
            for item in threshold_results["failed"]:
                report.append(f" {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        skipped = threshold_results.get("warnings") or []
        if skipped:
            report.append("\n⚠️ SKIPPED (non-numeric):")
            for item in skipped:
                report.append(f" {item['metric']}: {item['value']!r}")

        return "\n".join(report)

    def run(self, predictions: Optional[List[Dict]] = None) -> AgentResult:
        """Evaluate stored results or the given predictions.

        Stored results take precedence over ``predictions``. Returns
        WAITING when there is nothing to evaluate or no thresholded metric
        could be checked, FAILED when any metric misses its threshold,
        SUCCESS otherwise.
        """
        self.status = AgentStatus.RUNNING
        self.log("Running evaluation...")

        # Try to load existing results
        existing = self.load_evaluation_results()

        if existing:
            self.log("Found existing evaluation results")
            self.metrics = existing
        elif predictions:
            self.log(f"Computing metrics from {len(predictions)} predictions")
            self.metrics = self.compute_quick_metrics(predictions)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No predictions available for evaluation",
            )

        # Check against thresholds
        threshold_results = self.evaluate_against_thresholds(self.metrics)

        # Generate report
        report = self.generate_report(self.metrics, threshold_results)
        self.log(f"\n{report}")

        payload = {"metrics": self.metrics, "thresholds": threshold_results}

        if threshold_results["failed"]:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"{len(threshold_results['failed'])} metrics below threshold",
                payload,
            )

        if not threshold_results["passed"]:
            # Nothing was actually checked (e.g. predictions without ground
            # truth); reporting "All 0 metrics passed" would be misleading.
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No thresholded metrics available to evaluate",
                payload,
            )

        self.status = AgentStatus.SUCCESS
        return AgentResult(
            self.name,
            AgentStatus.SUCCESS,
            f"All {len(threshold_results['passed'])} metrics passed",
            payload,
        )
713
+
714
+ # ============================================================
715
+ # Agent 7: Demo Sync Agent
716
+ # ============================================================
717
+
718
class DemoSyncAgent(BaseAgent):
    """Checks the Gradio Space and stages curated examples for it.

    This class only verifies that the Space is running and that the
    examples can be serialised to JSON; it does not upload anything.
    """

    def __init__(self, space_id: str = "mmrech/pitvqa-surgical-vlm"):
        super().__init__("DemoSync")
        self.space_id = space_id

    def run(self, curated_examples: Optional[List[Dict]] = None) -> AgentResult:
        """Check Space readiness for the given curated examples.

        Returns WAITING when there are no examples or the Space is not
        running, FAILED on client/Space/serialisation errors, SUCCESS when
        the Space is running and the examples serialise.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Syncing to Space: {self.space_id}")

        if not curated_examples:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No curated examples to sync",
            )

        try:
            from huggingface_hub import HfApi

            api = HfApi()
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")

        # Check Space status
        try:
            info = api.space_info(self.space_id)
            runtime = info.runtime
        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Space error: {e}")

        if not (runtime and runtime.stage == "RUNNING"):
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Space not running: {runtime.stage if runtime else 'unknown'}",
            )

        self.log("Space is running", "SUCCESS")

        # Confirm the examples serialise to JSON. The result is discarded;
        # the call exists so that a payload problem is reported as such
        # rather than as a Space error.
        try:
            json.dumps(curated_examples, indent=2)
        except (TypeError, ValueError) as e:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Curated examples are not JSON-serializable: {e}",
            )

        self.status = AgentStatus.SUCCESS
        return AgentResult(
            self.name,
            AgentStatus.SUCCESS,
            f"Space running, {len(curated_examples)} examples ready for sync",
            {"space_status": "RUNNING", "examples_count": len(curated_examples)},
        )
774
+
775
+ # ============================================================
776
+ # Orchestrator
777
+ # ============================================================
778
+
779
class PitVQAOrchestrator:
    """Coordinates all agents for the PitVQA pipeline."""

    # Console icon per agent status value; anything else gets "❓".
    _STATUS_ICONS = {"success": "✅", "failed": "❌", "waiting": "⏳", "running": "🔄"}

    def __init__(self, job_ids: List[str]):
        self.agents = {
            "monitor": JobMonitorAgent(job_ids),
            "curation": CurationAgent(),
            "dataset": DatasetValidatorAgent(),
            "model": ModelVerifierAgent(),
            "training": TrainingSpecialistAgent(),  # HF-LLM-Trainer specialist
            "evaluation": EvaluationSpecialistAgent(),  # Eval-Model specialist
            "demo": DemoSyncAgent(),
        }
        # Results of the most recent cycle, keyed by agent key.
        self.results = {}
        self.run_count = 0

    def run_cycle(self) -> Dict:
        """Run one orchestration cycle and return its report.

        Job monitoring and training validation always run. The remaining
        agents run unless the job monitor reports FAILED.
        """
        self.run_count += 1
        # Start from a clean slate so a report never mixes in results left
        # over from an earlier cycle (e.g. when downstream phases are
        # skipped this time).
        self.results = {}

        print(f"\n{'='*60}")
        print(f"🔄 ORCHESTRATION CYCLE {self.run_count}")
        print(f"{'='*60}")

        # Phase 1: Check job status
        print("\n📊 Phase 1: Job Monitoring")
        monitor_result = self.agents["monitor"].run()
        self.results["monitor"] = monitor_result

        # Phase 2: Training Specialist - Validate adapter config
        print("\n🎓 Phase 2: Training Validation (HF-LLM-Trainer)")
        training_result = self.agents["training"].run()
        self.results["training"] = training_result

        # Phase 3+: run downstream agents unless a job has failed. They also
        # run while jobs are still in progress; each reports WAITING on its
        # own when its inputs are not there yet.
        if monitor_result.status in (AgentStatus.SUCCESS, AgentStatus.WAITING):

            # Run curation
            print("\n🎨 Phase 3: Curation")
            curation_result = self.agents["curation"].run()
            self.results["curation"] = curation_result

            # Run dataset validation
            print("\n📦 Phase 4: Dataset Validation")
            dataset_result = self.agents["dataset"].run()
            self.results["dataset"] = dataset_result

            # Run model verification
            print("\n🤖 Phase 5: Model Verification")
            model_result = self.agents["model"].run()
            self.results["model"] = model_result

            # Run evaluation specialist
            print("\n📈 Phase 6: Evaluation (Metrics & Quality)")
            curated = curation_result.data.get("examples", []) if curation_result.data else []
            eval_result = self.agents["evaluation"].run(predictions=curated)
            self.results["evaluation"] = eval_result

            # Run demo sync (reports WAITING itself when nothing was curated)
            print("\n🌐 Phase 7: Demo Sync")
            demo_result = self.agents["demo"].run(curated)
            self.results["demo"] = demo_result

        return self.generate_report()

    def generate_report(self) -> Dict:
        """Generate comprehensive status report.

        The report holds a timestamp, the cycle number, the overall status
        and, per agent, its status value and message.
        """
        return {
            "timestamp": datetime.now().isoformat(),
            "cycle": self.run_count,
            "overall_status": self._compute_overall_status(),
            "agents": {
                name: {"status": result.status.value, "message": result.message}
                for name, result in self.results.items()
            },
        }

    def _compute_overall_status(self) -> str:
        """Compute overall pipeline status.

        ``UNKNOWN`` when there are no results, ``COMPLETE`` when every agent
        succeeded, ``NEEDS_ATTENTION`` when any failed, ``IN_PROGRESS`` when
        any is waiting, ``UNKNOWN`` otherwise.
        """
        statuses = [r.status for r in self.results.values()]

        # all() is True for an empty list; without this guard a pipeline
        # that has not run anything would be reported as COMPLETE.
        if not statuses:
            return "UNKNOWN"

        if all(s == AgentStatus.SUCCESS for s in statuses):
            return "COMPLETE"
        if any(s == AgentStatus.FAILED for s in statuses):
            return "NEEDS_ATTENTION"
        if any(s == AgentStatus.WAITING for s in statuses):
            return "IN_PROGRESS"
        return "UNKNOWN"

    def print_summary(self, report: Dict):
        """Print human-readable summary of a report from generate_report()."""
        print(f"\n{'='*60}")
        print("📋 ORCHESTRATION SUMMARY")
        print(f"{'='*60}")
        print(f"Time: {report['timestamp']}")
        print(f"Cycle: {report['cycle']}")
        print(f"Overall: {report['overall_status']}")
        print("\nAgent Status:")
        for name, info in report["agents"].items():
            icon = self._STATUS_ICONS.get(info["status"], "❓")
            # Messages are truncated to 50 characters to keep one line each.
            print(f" {icon} {name}: {info['status']} - {info['message'][:50]}")
885
+
886
+ # ============================================================
887
+ # Main
888
+ # ============================================================
889
+
890
def main(job_ids: Optional[List[str]] = None):
    """Run one orchestration cycle, print a summary and save the report.

    Args:
        job_ids: HuggingFace job IDs to monitor. When omitted, the two job
            IDs listed below are used.

    Returns:
        The report dict, which is also written to
        ``orchestration_report.json`` in the current directory.
    """
    print("🚀 PitVQA Multi-Agent Orchestrator Starting...")

    if job_ids is None:
        # Current job IDs
        job_ids = [
            "696cfe9946affbb321046bd9",  # Curation job
            "696cfebf57a10a9d296ca042",  # Merge job
        ]

    orchestrator = PitVQAOrchestrator(job_ids)

    # Run orchestration cycle
    report = orchestrator.run_cycle()
    orchestrator.print_summary(report)

    # Save report; explicit encoding keeps the file independent of the
    # platform's default.
    with open("orchestration_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print("\n💾 Report saved to orchestration_report.json")

    return report


if __name__ == "__main__":
    main()