# Source: env.py, uploaded by Vaish6 (commit fcbd15c, verified).
import json
import logging
import os
from typing import Literal, List, Optional, Dict, Any
from pydantic import BaseModel
# Configure the root logger at import time with an INFO threshold.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Observation(BaseModel):
    """What the agent sees after a reset or a non-terminal step."""

    # Short free-text description of the patient's presentation.
    clinical_notes: str
    # Budget still available for paid tool calls in this episode.
    available_budget: int
    # Paths of the scans acquired so far in this episode.
    acquired_scans: List[str]
    # Most recent tool messages (the environment passes the last five).
    tool_outputs: List[str]
    # Number of steps taken in the current episode.
    step_count: int
    # Difficulty tier the environment derives from the ground-truth label.
    task_id: Literal["easy", "medium", "hard"]
class Action(BaseModel):
    """A single tool invocation requested by the agent."""

    # Which tool to run; restricted to the four tools the environment handles.
    tool_name: Literal[
        "request_oct_scan",
        "enhance_contrast",
        "measure_fluid_thickness",
        "submit_diagnosis",
    ]
    # Tool arguments. "submit_diagnosis" reads the keys "diagnosis",
    # "heatmap_coordinates" and "reasoning"; the other tools read none.
    parameters: Dict[str, Any]
class StepResult(BaseModel):
    """Outcome of one environment step."""

    # Next observation, or None once the episode has ended.
    observation: Optional[Observation]
    # Reward earned (or penalty incurred) by this step.
    reward: float
    # True when the episode is over.
    done: bool
    # Extra diagnostics: scoring breakdown on submission, or an "error" entry.
    info: dict
def calculate_iou(box1: List[List[int]], box2: List[List[int]]) -> float:
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is given as two opposite corners, ``[[x_a, y_a], [x_b, y_b]]``.
    The corners may be supplied in any order; they are normalised to
    (min, max) per axis before the overlap is computed, so a box described
    bottom-right-first scores the same as the identical box described
    top-left-first.

    Args:
        box1: First box as two ``[x, y]`` corners.
        box2: Second box as two ``[x, y]`` corners.

    Returns:
        The IoU in ``[0.0, 1.0]``. Boxes that do not overlap, that only
        touch along an edge, or whose union has no area yield ``0.0``.
    """
    # Normalise each axis so the arithmetic below never sees a negative extent.
    a_left, a_right = sorted((box1[0][0], box1[1][0]))
    a_top, a_bottom = sorted((box1[0][1], box1[1][1]))
    b_left, b_right = sorted((box2[0][0], box2[1][0]))
    b_top, b_bottom = sorted((box2[0][1], box2[1][1]))

    overlap_w = min(a_right, b_right) - max(a_left, b_left)
    overlap_h = min(a_bottom, b_bottom) - max(a_top, b_top)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    inter_area = overlap_w * overlap_h
    area_a = (a_right - a_left) * (a_bottom - a_top)
    area_b = (b_right - b_left) * (b_bottom - b_top)
    union_area = area_a + area_b - inter_area
    if union_area <= 0:
        return 0.0
    return inter_area / union_area
class MetaOCTEnv:
    """Sequential, budget-limited OCT diagnosis environment.

    The ground-truth JSON maps image file names to records holding a
    ``label``, a ``box`` (two ``[x, y]`` corners) and a list of ``keywords``.
    Each record is one patient. Within a patient episode the agent spends a
    fixed budget on tools and finishes by submitting a diagnosis, which is
    scored on label match, box IoU and keyword coverage of its reasoning,
    then scaled by the fraction of budget left.

    The caller is expected to invoke ``reset()`` before each patient;
    ``step()`` advances to the next patient when an episode ends but does
    not restore the per-episode budget or scan state itself.
    """

    # Price charged for each paid tool call.
    _TOOL_COSTS = {
        "request_oct_scan": 150,
        "enhance_contrast": 50,
        "measure_fluid_thickness": 200,
    }
    # Starting budget per difficulty; any other difficulty uses the default.
    _INITIAL_BUDGETS = {"easy": 1000, "hard": 200}
    _DEFAULT_BUDGET = 400

    def __init__(self, data_dir: str = ".", truth_file: str = "ground_truth.json", difficulty: str = "medium"):
        """Load the ground truth and initialise per-episode state.

        Args:
            data_dir: Directory prefix joined onto image names to form scan paths.
            truth_file: Path of the ground-truth JSON file (opened as given,
                not relative to ``data_dir``).
            difficulty: "easy", "medium" or "hard" (case-insensitive); selects
                the starting budget. Unrecognised values behave like "medium".

        Raises:
            OSError: If ``truth_file`` cannot be opened.
            json.JSONDecodeError: If ``truth_file`` is not valid JSON.
        """
        self.data_dir = data_dir
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(truth_file, "r", encoding="utf-8") as f:
            self.ground_truth = json.load(f)
        self.image_files = list(self.ground_truth.keys())
        self.current_idx = 0
        self.max_patients = len(self.image_files)
        self.difficulty = difficulty.lower()
        self.initial_budget = self._INITIAL_BUDGETS.get(self.difficulty, self._DEFAULT_BUDGET)
        self.available_budget = self.initial_budget
        self.acquired_scans = []
        self.tool_outputs = []
        self.step_count = 0
        self.max_steps = 10
        self.contrast_enhanced = False

    def state(self) -> dict:
        """Return progress through the patient list as a plain dict."""
        return {
            "current_idx": self.current_idx,
            "max_patients": self.max_patients,
            "is_done": self.current_idx >= self.max_patients
        }

    async def reset(self) -> Observation:
        """Start a fresh episode for the current patient and return its first observation.

        Restores the budget, clears acquired scans and the contrast flag, and
        replaces the tool log with a single arrival message. The patient index
        is not changed.
        """
        self.available_budget = self.initial_budget
        self.acquired_scans = []
        self.tool_outputs = [f"Patient arrived. You have a ${self.initial_budget} diagnostic budget."]
        self.step_count = 0
        self.contrast_enhanced = False
        return self._get_observation()

    def _get_observation(self) -> Observation:
        """Build the observation for the current patient.

        The index wraps with a modulo, so an observation can still be built
        after the last patient. With an empty ground truth the modulo raises
        ``ZeroDivisionError``.
        """
        img_name = self.image_files[self.current_idx % len(self.image_files)]
        label = self.ground_truth[img_name]["label"]

        # Difficulty tier is derived from the label, not from self.difficulty.
        if "CNV" in label:
            task_id = "hard"
        elif "DME" in label or "DRUSEN" in label:
            task_id = "medium"
        else:
            task_id = "easy"

        if task_id == "easy":
            clinical_notes = "Routine yearly diabetic eye checkup."
        else:
            clinical_notes = "Patient complains of blurry vision."

        return Observation(
            clinical_notes=clinical_notes,
            available_budget=self.available_budget,
            acquired_scans=self.acquired_scans,
            tool_outputs=self.tool_outputs[-5:],  # Keep last 5 outputs to prevent context bloat
            step_count=self.step_count,
            task_id=task_id
        )

    async def step(self, action: Action) -> StepResult:
        """Apply one agent action and return the resulting transition.

        Args:
            action: The tool call to execute.

        Returns:
            A ``StepResult``. ``observation`` is ``None`` whenever ``done`` is
            true. Running out of steps without submitting ends the episode
            with a reward of -1.0; a submission ends it with the graded score.
            Both advance to the next patient.
        """
        if self.current_idx >= self.max_patients:
            return StepResult(observation=None, reward=0.0, done=True, info={})

        self.step_count += 1
        img_name = self.image_files[self.current_idx]
        truth = self.ground_truth[img_name]
        tool = action.tool_name

        # The final allowed step must be a submission; anything else forfeits the patient.
        if self.step_count >= self.max_steps and tool != "submit_diagnosis":
            self.current_idx += 1
            info = {"error": "Max steps reached before diagnosis"}
            return StepResult(observation=None, reward=-1.0, done=True, info=info)

        reward = 0.0
        done = False
        info = {}
        if tool == "request_oct_scan":
            reward = self._request_oct_scan(img_name)
        elif tool == "enhance_contrast":
            reward = self._enhance_contrast()
        elif tool == "measure_fluid_thickness":
            reward = self._measure_fluid_thickness(truth)
        elif tool == "submit_diagnosis":
            done = True
            info = self._grade_submission(action.parameters, truth)
            reward = info["final_base_score"] * info["budget_efficiency"]
            self.tool_outputs.append(f"[submit_diagnosis] Evaluated. Score: {reward:.2f}")
            self.current_idx += 1
        else:
            reward = -0.1
            self.tool_outputs.append(f"[{action.tool_name}] Unknown tool.")

        obs = self._get_observation() if not done else None
        return StepResult(observation=obs, reward=reward, done=done, info=info)

    def _spend(self, cost: int) -> bool:
        """Deduct ``cost`` from the budget if affordable; report whether it was."""
        if self.available_budget < cost:
            return False
        self.available_budget -= cost
        return True

    def _request_oct_scan(self, img_name: str) -> float:
        """Handle "request_oct_scan"; return the step reward (0.0 or a penalty)."""
        if not self._spend(self._TOOL_COSTS["request_oct_scan"]):
            self.tool_outputs.append("[request_oct_scan] Error: Insufficient budget.")
            return -0.1
        img_path = os.path.join(self.data_dir, img_name)
        # A repeat request is still charged; that is the "wasted budget".
        if img_path in self.acquired_scans:
            self.tool_outputs.append("[request_oct_scan] Warning: Scan already acquired. Wasted budget.")
            return -0.05
        self.acquired_scans.append(img_path)
        self.tool_outputs.append(f"[request_oct_scan] Success. Scan acquired at {img_path}.")
        return 0.0

    def _enhance_contrast(self) -> float:
        """Handle "enhance_contrast"; return the step reward (0.0 or a penalty)."""
        if not self._spend(self._TOOL_COSTS["enhance_contrast"]):
            self.tool_outputs.append("[enhance_contrast] Error: Insufficient budget.")
            return -0.1
        # The cost has already been charged, even for the failure cases below.
        if not self.acquired_scans:
            self.tool_outputs.append("[enhance_contrast] Error: No scan to enhance. Request scan first.")
            return -0.05
        if self.contrast_enhanced:
            self.tool_outputs.append("[enhance_contrast] Warning: Already enhanced. Wasted budget.")
            return -0.05
        self.contrast_enhanced = True
        self.tool_outputs.append("[enhance_contrast] Success. Vision clarity improved by 1.2x.")
        return 0.0

    def _measure_fluid_thickness(self, truth: Dict[str, Any]) -> float:
        """Handle "measure_fluid_thickness"; return the step reward (0.0 or a penalty)."""
        if not self._spend(self._TOOL_COSTS["measure_fluid_thickness"]):
            self.tool_outputs.append("[measure_fluid] Error: Insufficient budget.")
            return -0.1
        # The cost has already been charged, even when there is nothing to measure.
        if not self.acquired_scans:
            self.tool_outputs.append("[measure_fluid] Error: No scan to measure. Request scan first.")
            return -0.05
        if truth["label"] in ["CNV", "DME"]:
            msg = f"[measure_fluid] Abnormal retinal thickening detected. Biomarkers found: {', '.join(truth['keywords'])}"
        else:
            msg = "[measure_fluid] Normal foveal contour observed. No abnormal fluid."
        self.tool_outputs.append(msg)
        return 0.0

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce an agent-supplied parameter to text; ``None`` becomes ""."""
        if isinstance(value, str):
            return value
        return "" if value is None else str(value)

    @staticmethod
    def _is_box(candidate: Any) -> bool:
        """Return True if ``candidate`` looks like two ``[x, y]`` numeric corners."""
        if not isinstance(candidate, (list, tuple)) or len(candidate) < 2:
            return False
        for corner in candidate[:2]:
            if not isinstance(corner, (list, tuple)) or len(corner) < 2:
                return False
            if not all(isinstance(v, (int, float)) for v in corner[:2]):
                return False
        return True

    def _grade_submission(self, parameters: Dict[str, Any], truth: Dict[str, Any]) -> Dict[str, Any]:
        """Score a "submit_diagnosis" call and return the scoring breakdown.

        Parameters come from the agent and are not trusted: non-string
        ``diagnosis``/``reasoning`` values are coerced to text, and a
        malformed ``heatmap_coordinates`` earns an IoU of 0.0 rather than
        raising.

        Returns:
            Dict with ``label_match``, ``iou_score``, ``reasoning_score``,
            ``budget_efficiency``, ``true_label`` and ``final_base_score``.
            The step reward is ``final_base_score * budget_efficiency``.
        """
        diagnosis = self._as_text(parameters.get("diagnosis", ""))
        reasoning = self._as_text(parameters.get("reasoning", ""))
        heatmap = parameters.get("heatmap_coordinates", [[0, 0], [0, 0]])
        heatmap_ok = self._is_box(heatmap)

        label_match = 1.0 if diagnosis.upper() == truth["label"].upper() else 0.0

        true_box = truth["box"]
        if true_box[0] == [0, 0] and true_box[1] == [0, 0]:
            # An all-zero truth box means "no lesion": only an all-zero answer is correct.
            answered_zero = heatmap_ok and heatmap[0] == [0, 0] and heatmap[1] == [0, 0]
            iou_score = 1.0 if answered_zero else 0.0
        elif heatmap_ok:
            iou_score = calculate_iou(heatmap, true_box)
        else:
            iou_score = 0.0
        # Contrast enhancement grants a 1.2x localisation bonus, capped at 1.0.
        if self.contrast_enhanced:
            iou_score = min(1.0, iou_score * 1.2)

        keywords = truth["keywords"]
        if keywords:
            reasoning_lower = reasoning.lower()
            matched = sum(1 for kw in keywords if kw.lower() in reasoning_lower)
            reasoning_score = matched / len(keywords)
        else:
            # Nothing to mention, so the reasoning cannot be incomplete.
            reasoning_score = 1.0

        base_reward = (0.3 * label_match) + (0.4 * iou_score) + (0.3 * reasoning_score)
        # Floor at 0.2 so a correct diagnosis is still worth something with the budget spent.
        budget_efficiency = max(0.2, self.available_budget / self.initial_budget)
        return {
            "label_match": label_match,
            "iou_score": iou_score,
            "reasoning_score": reasoning_score,
            "budget_efficiency": budget_efficiency,
            "true_label": truth["label"],
            "final_base_score": base_reward
        }

    async def close(self):
        """Release resources; this environment holds none, so it does nothing."""
        pass