Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import os | |
| from typing import Literal, List, Optional, Dict, Any | |
| from pydantic import BaseModel | |
# Configure the root logger at INFO level. basicConfig() is a no-op when the
# root logger already has handlers, so an embedding application's own logging
# setup is left alone.
logging.basicConfig(level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# What the agent sees after reset() and after every non-terminal step().
# (Documented with comments rather than a docstring so the generated pydantic
# schema is unchanged.)
class Observation(BaseModel):
    # Short free-text note for the current patient.
    clinical_notes: str
    # Budget still available for diagnostic tools.
    available_budget: int
    # Paths of the scans acquired so far for this patient.
    acquired_scans: List[str]
    # Most recent tool messages (the environment passes the last five).
    tool_outputs: List[str]
    # Number of steps taken for the current patient.
    step_count: int
    # Difficulty tag the environment derives from the case's label.
    task_id: Literal["easy", "medium", "hard"]
# One tool invocation chosen by the agent.
class Action(BaseModel):
    # Which tool to run; "submit_diagnosis" ends the episode.
    tool_name: Literal[
        "request_oct_scan",
        "enhance_contrast",
        "measure_fluid_thickness",
        "submit_diagnosis",
    ]
    # Tool arguments. Only "submit_diagnosis" reads any: "diagnosis",
    # "heatmap_coordinates" and "reasoning".
    parameters: Dict[str, Any]
# Outcome of a single environment step.
class StepResult(BaseModel):
    # Next observation; None once the episode has ended.
    observation: Optional[Observation]
    # Reward earned by this step (penalties are negative).
    reward: float
    # True when the episode is over.
    done: bool
    # Extra diagnostics: score breakdown on submission, error text on timeout.
    info: dict
def calculate_iou(box1: List[List[int]], box2: List[List[int]]) -> float:
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is ``[[x_a, y_a], [x_b, y_b]]``, i.e. two opposite corners. Only
    the first two entries of each corner are read. The corners may be given in
    either order: they are normalised to (min, max) extents first, so a box
    written "bottom-right first" is scored the same as its conventional form
    instead of silently scoring zero.

    Args:
        box1: First box.
        box2: Second box.

    Returns:
        ``intersection_area / union_area``, or ``0.0`` when the boxes do not
        overlap, merely touch along an edge or corner, or have no area.
    """

    def _extent(box):
        # Normalise two opposite corners to (left, top, right, bottom).
        xa, ya = box[0][0], box[0][1]
        xb, yb = box[1][0], box[1][1]
        return min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)

    left1, top1, right1, bottom1 = _extent(box1)
    left2, top2, right2, bottom2 = _extent(box2)

    inter_left = max(left1, left2)
    inter_top = max(top1, top2)
    inter_right = min(right1, right2)
    inter_bottom = min(bottom1, bottom2)

    # Empty or degenerate (zero-width / zero-height) intersection.
    if inter_left >= inter_right or inter_top >= inter_bottom:
        return 0.0

    inter_area = (inter_right - inter_left) * (inter_bottom - inter_top)
    area1 = (right1 - left1) * (bottom1 - top1)
    area2 = (right2 - left2) * (bottom2 - top2)
    union_area = area1 + area2 - inter_area

    # Defensive guard against division by zero.
    if union_area <= 0:
        return 0.0
    return inter_area / union_area
class MetaOCTEnv:
    """Tool-driven OCT diagnosis environment with a per-patient budget.

    The ground-truth JSON maps each image file name to a record that this
    class reads through the keys ``"label"``, ``"box"`` and ``"keywords"``.
    Patients are served in file order. An episode covers one patient and ends
    when a diagnosis is submitted or the step limit is reached.

    Call ``reset()`` before every episode: ``step()`` moves on to the next
    patient when an episode ends, but it does not restore the budget or clear
    the per-episode state.
    """

    # Starting budget per difficulty; any other value gets the default.
    _INITIAL_BUDGETS = {"easy": 1000, "hard": 200}
    _DEFAULT_INITIAL_BUDGET = 400

    # Price charged for each diagnostic tool. The charge is taken whenever the
    # tool is affordable, including when the call turns out to be wasted.
    _TOOL_COSTS = {
        "request_oct_scan": 150,
        "enhance_contrast": 50,
        "measure_fluid_thickness": 200,
    }

    def __init__(self, data_dir: str = ".", truth_file: str = "ground_truth.json", difficulty: str = "medium"):
        """Load the ground truth and initialise the per-patient state.

        Args:
            data_dir: Directory joined onto an image name when a scan is
                acquired.
            truth_file: Path of the ground-truth JSON file. It is opened as
                given; it is not resolved relative to ``data_dir``.
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``
                (case-insensitive). Unrecognised values use the medium budget.

        Raises:
            OSError: If ``truth_file`` cannot be opened.
            json.JSONDecodeError: If ``truth_file`` is not valid JSON.
        """
        self.data_dir = data_dir
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(truth_file, "r", encoding="utf-8") as f:
            self.ground_truth = json.load(f)
        self.image_files = list(self.ground_truth.keys())
        self.current_idx = 0
        self.max_patients = len(self.image_files)

        self.difficulty = difficulty.lower()
        self.initial_budget = self._INITIAL_BUDGETS.get(
            self.difficulty, self._DEFAULT_INITIAL_BUDGET
        )

        self.available_budget = self.initial_budget
        self.acquired_scans = []
        self.tool_outputs = []
        self.step_count = 0
        self.max_steps = 10
        self.contrast_enhanced = False

    def state(self) -> dict:
        """Return progress through the patient list as a plain dict."""
        return {
            "current_idx": self.current_idx,
            "max_patients": self.max_patients,
            "is_done": self.current_idx >= self.max_patients,
        }

    async def reset(self) -> Observation:
        """Start a fresh episode for the current patient.

        Restores the budget and clears the scans, tool messages, step counter
        and contrast flag. The patient pointer is left where it is.

        NOTE: once every patient has been served, the returned observation
        wraps around to an earlier patient, while ``step()`` reports ``done``
        immediately.
        """
        self.available_budget = self.initial_budget
        self.acquired_scans = []
        self.tool_outputs = [f"Patient arrived. You have a ${self.initial_budget} diagnostic budget."]
        self.step_count = 0
        self.contrast_enhanced = False
        return self._get_observation()

    def _get_observation(self) -> Observation:
        """Build the observation for the current patient."""
        img_name = self.image_files[self.current_idx % len(self.image_files)]
        label = self.ground_truth[img_name]["label"]

        if "CNV" in label:
            task_id = "hard"
        elif "DME" in label or "DRUSEN" in label:
            task_id = "medium"
        else:
            task_id = "easy"

        if task_id == "easy":
            clinical_notes = "Routine yearly diabetic eye checkup."
        else:
            clinical_notes = "Patient complains of blurry vision."

        return Observation(
            clinical_notes=clinical_notes,
            available_budget=self.available_budget,
            acquired_scans=self.acquired_scans,
            tool_outputs=self.tool_outputs[-5:],  # Keep last 5 outputs to prevent context bloat
            step_count=self.step_count,
            task_id=task_id,
        )

    async def step(self, action: Action) -> StepResult:
        """Apply one action and return its result.

        Non-terminal tools return a small penalty (or 0.0) together with the
        next observation. ``submit_diagnosis`` scores the submission, ends the
        episode and advances to the next patient. Reaching ``max_steps`` with
        any other tool ends the episode with a reward of -1.0 without running
        the tool. After the last patient, every call returns ``done=True``
        with a reward of 0.0.
        """
        if self.current_idx >= self.max_patients:
            return StepResult(observation=None, reward=0.0, done=True, info={})

        self.step_count += 1
        img_name = self.image_files[self.current_idx]
        truth = self.ground_truth[img_name]
        tool = action.tool_name

        if self.step_count >= self.max_steps and tool != "submit_diagnosis":
            self.current_idx += 1
            info = {"error": "Max steps reached before diagnosis"}
            return StepResult(observation=None, reward=-1.0, done=True, info=info)

        if tool == "submit_diagnosis":
            reward, info = self._submit_diagnosis(action.parameters, truth)
            return StepResult(observation=None, reward=reward, done=True, info=info)

        if tool == "request_oct_scan":
            reward = self._request_oct_scan(img_name)
        elif tool == "enhance_contrast":
            reward = self._enhance_contrast()
        elif tool == "measure_fluid_thickness":
            reward = self._measure_fluid_thickness(truth)
        else:
            reward = -0.1
            self.tool_outputs.append(f"[{tool}] Unknown tool.")

        return StepResult(observation=self._get_observation(), reward=reward, done=False, info={})

    def _charge(self, tool_name: str) -> bool:
        """Deduct the tool's cost if affordable; return whether it was charged."""
        cost = self._TOOL_COSTS[tool_name]
        if self.available_budget < cost:
            return False
        self.available_budget -= cost
        return True

    def _request_oct_scan(self, img_name: str) -> float:
        """Acquire the current patient's scan and return the step reward."""
        if not self._charge("request_oct_scan"):
            self.tool_outputs.append("[request_oct_scan] Error: Insufficient budget.")
            return -0.1

        img_path = os.path.join(self.data_dir, img_name)
        if img_path in self.acquired_scans:
            self.tool_outputs.append("[request_oct_scan] Warning: Scan already acquired. Wasted budget.")
            return -0.05

        self.acquired_scans.append(img_path)
        self.tool_outputs.append(f"[request_oct_scan] Success. Scan acquired at {img_path}.")
        return 0.0

    def _enhance_contrast(self) -> float:
        """Mark the scan as contrast-enhanced and return the step reward."""
        if not self._charge("enhance_contrast"):
            self.tool_outputs.append("[enhance_contrast] Error: Insufficient budget.")
            return -0.1

        if not self.acquired_scans:
            self.tool_outputs.append("[enhance_contrast] Error: No scan to enhance. Request scan first.")
            return -0.05
        if self.contrast_enhanced:
            self.tool_outputs.append("[enhance_contrast] Warning: Already enhanced. Wasted budget.")
            return -0.05

        self.contrast_enhanced = True
        self.tool_outputs.append("[enhance_contrast] Success. Vision clarity improved by 1.2x.")
        return 0.0

    def _measure_fluid_thickness(self, truth: Dict[str, Any]) -> float:
        """Report the fluid finding for the current case and return the step reward."""
        if not self._charge("measure_fluid_thickness"):
            self.tool_outputs.append("[measure_fluid] Error: Insufficient budget.")
            return -0.1

        if not self.acquired_scans:
            self.tool_outputs.append("[measure_fluid] Error: No scan to measure. Request scan first.")
            return -0.05

        if truth["label"] in ["CNV", "DME"]:
            msg = f"[measure_fluid] Abnormal retinal thickening detected. Biomarkers found: {', '.join(truth['keywords'])}"
        else:
            msg = "[measure_fluid] Normal foveal contour observed. No abnormal fluid."
        self.tool_outputs.append(msg)
        return 0.0

    @staticmethod
    def _parse_box(raw: Any) -> Optional[List[List[float]]]:
        """Return ``[[x, y], [x, y]]`` from an agent-supplied box, or None if malformed.

        Only the first two corners and the first two coordinates of each are
        used, which is all the IoU computation reads.
        """
        try:
            corners = [[raw[0][0], raw[0][1]], [raw[1][0], raw[1][1]]]
        except (TypeError, IndexError, KeyError):
            return None
        if all(isinstance(value, (int, float)) for corner in corners for value in corner):
            return corners
        return None

    def _submit_diagnosis(self, parameters: Dict[str, Any], truth: Dict[str, Any]) -> tuple:
        """Score a submission, advance to the next patient, return ``(reward, info)``.

        The base score is 0.3 * label match + 0.4 * localisation score
        + 0.3 * keyword coverage of the reasoning text, then scaled by the
        fraction of budget left (never less than 0.2).

        Parameters come from the agent and are not trusted: a non-string
        diagnosis or reasoning, or a malformed box, scores zero for that
        component instead of raising.
        """
        diagnosis = parameters.get("diagnosis", "")
        heatmap = parameters.get("heatmap_coordinates", [[0, 0], [0, 0]])
        reasoning = parameters.get("reasoning", "")

        is_match = isinstance(diagnosis, str) and diagnosis.upper() == truth["label"].upper()
        label_match = 1.0 if is_match else 0.0

        true_box = truth["box"]
        box = self._parse_box(heatmap)
        if true_box[0] == [0, 0] and true_box[1] == [0, 0]:
            # An all-zero ground-truth box is scored by exact match, not IoU
            # (presumably it marks a case with nothing to localise).
            iou_score = 1.0 if box == [[0, 0], [0, 0]] else 0.0
        elif box is None:
            iou_score = 0.0
        else:
            iou_score = calculate_iou(box, true_box)
        if self.contrast_enhanced:
            # Contrast enhancement boosts the localisation score, capped at 1.0.
            iou_score = min(1.0, iou_score * 1.2)

        reasoning_lower = reasoning.lower() if isinstance(reasoning, str) else ""
        keywords = truth["keywords"]
        if keywords:
            matched = sum(1 for kw in keywords if kw.lower() in reasoning_lower)
            reasoning_score = matched / len(keywords)
        else:
            reasoning_score = 1.0

        base_reward = (0.3 * label_match) + (0.4 * iou_score) + (0.3 * reasoning_score)
        budget_efficiency = max(0.2, self.available_budget / self.initial_budget)
        reward = base_reward * budget_efficiency

        info = {
            "label_match": label_match,
            "iou_score": iou_score,
            "reasoning_score": reasoning_score,
            "budget_efficiency": budget_efficiency,
            "true_label": truth["label"],
            "final_base_score": base_reward,
        }
        self.tool_outputs.append(f"[submit_diagnosis] Evaluated. Score: {reward:.2f}")
        self.current_idx += 1
        return reward, info

    async def close(self):
        """Release resources; this environment holds none."""
        pass