# Source: mimic-svm / code / generate_val_outputs.py (Hugging Face Hub file page)
# Uploaded by: ahmad4raza
# Commit: 097b6c6 (verified) - "Upload folder using huggingface_hub"
#
import os
# GPU selection is done through environment variables that are written before
# the unsloth import below - presumably so they are already in place when any
# CUDA initialization happens; keep this ordering.
# NOTE(review): "CUDA_AVAILABLE_DEVICES" does not appear to be a variable the
# CUDA runtime reads (CUDA_VISIBLE_DEVICES is the documented one) - confirm
# and drop it if nothing else consumes it.
os.environ["CUDA_AVAILABLE_DEVICES"] = "6"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
from unsloth import FastVisionModel
import json
from pathlib import Path
from typing import Dict, List, Optional
import torch
from PIL import Image
from transformers import AutoProcessor
# Fine-tuned checkpoint directory passed to FastVisionModel.from_pretrained in
# main() (a LoRA checkpoint, going by the path name and the note in
# load_processor).
CHECKPOINT_PATH = Path("outputs/mimic_qwen3vl_lora_8bit_5/checkpoint-17454")
# Base model identifier the processor is loaded from (see load_processor).
BASE_MODEL_NAME = "unsloth/Qwen3-VL-8B-Thinking"
# Prompt file expected next to this script; main() reads it and passes its
# text to generate_report as prompt_text.
SYSTEM_PROMPT_PATH = Path(__file__).with_name("new_system_prompt.txt")
# Validation root: study ids come from <VAL_ROOT>/files/*.txt and images from
# <VAL_ROOT>/images/<study_id>/.
VAL_ROOT = Path("dataset/val")
# Output root: per-study reports go to <OUTPUT_DIR>/reports/<study_id>.txt and
# the manifest to <OUTPUT_DIR>/predictions.jsonl.
OUTPUT_DIR = Path("dataset/val_outputs")
# Forwarded to model.generate as max_new_tokens.
MAX_NEW_TOKENS = 4096
# generate_report passes do_sample=True and this temperature to generate()
# only when the value is > 0.
TEMPERATURE = 0.0
# Seed given to torch.manual_seed in main().
SEED = 3407
# Per-study image cap; get_study_images applies it only when it is > 0, so 0
# means "use every image".
MAX_IMAGES_PER_STUDY = 0
# Quantization flags forwarded to FastVisionModel.from_pretrained.
LOAD_IN_4BIT = False
LOAD_IN_8BIT = True
# File suffixes (compared lower-cased) that get_study_images treats as images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
def clean_text(text: str) -> str:
    """Return *text* reduced to its non-blank lines, each stripped.

    Whitespace-only lines are dropped and the remaining lines are joined
    with single newlines.
    """
    kept = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if stripped:
            kept.append(stripped)
    return "\n".join(kept).strip()
def extract_findings_impression(text: str) -> Optional[str]:
    """Pull the Findings and Impression sections out of a free-text report.

    A section header is a line that starts (after optional spaces/tabs) with a
    letter followed by at least three more characters from letters, spaces,
    parentheses or slashes, and then a colon. Headers whose upper-cased name
    contains "PROVISIONAL" are ignored entirely, so they neither start nor end
    a section. A section's body runs from the end of its header to the start
    of the next recognized header (or the end of the text).

    For each of FINDINGS and IMPRESSION the first occurrence with a non-empty
    body wins. Body lines are stripped; blank lines and separator rules made
    of five or more ``_``, ``-`` or ``=`` characters are dropped.

    Returns:
        "Findings:\\n..." and/or "Impression:\\n..." joined by a blank line
        (Findings first), or None when neither section has content.
    """
    import re

    header_re = re.compile(r"^[ \t]*([A-Za-z][A-Za-z ()/]{3,}):[ \t]*", re.MULTILINE)
    rule_re = re.compile(r"^[_\-=]{5,}$")
    wanted = {"FINDINGS": "Findings", "IMPRESSION": "Impression"}

    # Normalize Windows / old-Mac line endings so offsets and splitlines agree.
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")

    headers = []
    for match in header_re.finditer(normalized):
        title = match.group(1).strip().upper()
        if "PROVISIONAL" not in title:
            headers.append((title, match.start(), match.end()))

    # Each section ends where the following header begins; the last one runs
    # to the end of the text.
    section_ends = [start for _title, start, _end in headers[1:]]
    section_ends.append(len(normalized))

    sections = {}
    for (title, _header_start, body_start), body_end in zip(headers, section_ends):
        label = wanted.get(title)
        if label is None or label in sections:
            continue
        body_lines = []
        for candidate in normalized[body_start:body_end].splitlines():
            candidate = candidate.strip()
            if candidate and not rule_re.match(candidate):
                body_lines.append(candidate)
        if body_lines:
            sections[label] = "\n".join(body_lines)

    if not sections:
        return None
    return "\n\n".join(
        f"{heading}:\n{sections[heading]}"
        for heading in ("Findings", "Impression")
        if heading in sections
    )
def clean_report_text(text: str) -> str:
    """Normalize raw model output into the report text that gets saved.

    Uses the extracted Findings/Impression sections when
    extract_findings_impression returns any; otherwise falls back to
    clean_text on the whole input.
    """
    sections = extract_findings_impression(text)
    return sections.strip() if sections else clean_text(text)
def load_images(image_paths: List[Path]) -> List[Image.Image]:
    """Open every path as an RGB image, skipping the ones that cannot be read.

    A failure on one image prints a "[warn]" line and is otherwise ignored,
    so the returned list may be shorter than *image_paths* (or empty).
    """
    loaded: List[Image.Image] = []
    for image_path in image_paths:
        try:
            # convert() returns a new image, so it stays usable after the
            # file handle opened by Image.open is closed.
            with Image.open(image_path) as source:
                rgb = source.convert("RGB")
        except Exception as error:
            print(f"[warn] Skipping unreadable image {image_path}: {error}")
        else:
            loaded.append(rgb)
    return loaded
def build_messages(images: List[Image.Image], prompt_text: str) -> List[Dict]:
    """Build the single-turn chat payload: all image blocks, then the text.

    The layout (one user turn, image blocks first, one text block last) is
    kept consistent with base-inference.py.
    """
    image_blocks = [{"type": "image", "image": image} for image in images]
    text_block = {"type": "text", "text": prompt_text}
    return [{"role": "user", "content": [*image_blocks, text_block]}]
def generate_report(
    model,
    processor,
    image_paths: List[Path],
    prompt_text: str,
    max_new_tokens: int,
    temperature: float,
) -> str:
    """Generate one cleaned report for the images of a single study.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        processor: Processor exposing ``apply_chat_template()`` and ``decode()``.
        image_paths: Image files for the study; unreadable ones are skipped.
        prompt_text: Text block placed after the images in the user turn.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. Values > 0 enable sampling at that
            temperature; values <= 0 request deterministic (non-sampling)
            decoding.

    Returns:
        The generated text after clean_report_text, or the placeholder
        "<no readable images>" when none of the images could be loaded.
    """
    images = load_images(image_paths)
    if not images:
        return "<no readable images>"

    messages = build_messages(images, prompt_text)
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    generate_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    else:
        # Be explicit: without this, generate() falls back to the loaded
        # checkpoint's generation config, which may enable sampling and make
        # a "temperature 0" run non-deterministic.
        generate_kwargs["do_sample"] = False

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            **generate_kwargs,
        )

    # generate() returns the full prompt followed by the new tokens, so the
    # prompt length is the input width. Counting attention-mask ones instead
    # would under-count whenever the mask contains zeros and leak prompt
    # tokens into the decoded report.
    prompt_len = int(inputs["input_ids"].shape[-1])
    text = processor.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return clean_report_text(text)
def load_processor(checkpoint_path: Path, base_model: str):
    """Return the multimodal processor for *base_model*.

    *checkpoint_path* is accepted but intentionally unused: to match the
    working inference script, the full processor is always loaded from the
    base model rather than from a LoRA checkpoint folder.
    """
    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
    return processor
def is_completed_report(report_path: Path) -> bool:
    """Tell whether *report_path* already holds a usable prediction.

    Returns False when the file is missing, unreadable, empty after
    stripping, or contains only a failure placeholder (text starting with
    "<tokenization failed:" or equal to "<no readable images>"), so that such
    studies are regenerated. Returns True otherwise.
    """
    if not report_path.exists():
        return False
    try:
        content = report_path.read_text(encoding="utf-8", errors="ignore").strip()
    except OSError:
        return False
    # Placeholders written by earlier failed attempts do not count as done.
    is_placeholder = (
        content.startswith("<tokenization failed:")
        or content == "<no readable images>"
    )
    return bool(content) and not is_placeholder
def load_existing_manifest_ids(manifest_path: Path) -> set:
    """Collect the study ids already recorded in a JSONL manifest.

    Reading is best-effort: blank lines, lines that are not valid JSON, JSON
    values that are not objects, and rows without a truthy "study_id" are
    skipped. Undecodable bytes are replaced rather than raising, so a
    partially written last line cannot abort a resumed run.

    Args:
        manifest_path: Path of the manifest; it does not have to exist.

    Returns:
        Set of study ids as strings (empty when the file is missing). If the
        file cannot be read completely, a warning is printed and the ids
        gathered so far are returned.
    """
    existing_ids: set = set()
    if not manifest_path.exists():
        return existing_ids
    try:
        with manifest_path.open("r", encoding="utf-8", errors="replace") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A line such as `[1, 2]`, `42` or `null` is valid JSON but
                # has no .get(); skip it instead of crashing.
                if not isinstance(row, dict):
                    continue
                study_id = row.get("study_id")
                if study_id:
                    existing_ids.add(str(study_id))
    except OSError as error:
        # Stay best-effort, but do not hide the problem: missing ids only
        # cause duplicate manifest rows, which is worth a visible warning.
        print(f"[warn] Could not fully read manifest {manifest_path}: {error}")
    return existing_ids
def get_study_ids(val_root: Path) -> List[str]:
    """List study ids as the sorted stems of ``<val_root>/files/*.txt``.

    Raises:
        FileNotFoundError: If ``<val_root>/files`` is not a directory.
    """
    files_dir = val_root / "files"
    if files_dir.is_dir():
        stems = [report_file.stem for report_file in files_dir.glob("*.txt")]
        stems.sort()
        return stems
    raise FileNotFoundError(f"Missing directory: {files_dir}")
def get_study_images(val_root: Path, study_id: str, max_images_per_study: int) -> List[Path]:
    """Return the sorted image files under ``<val_root>/images/<study_id>``.

    Only regular files whose lower-cased suffix is in IMAGE_EXTENSIONS are
    kept. When *max_images_per_study* is greater than zero the sorted list is
    truncated to that many entries; otherwise every image is returned. A
    missing study directory yields an empty list.
    """
    study_dir = val_root / "images" / study_id
    if not study_dir.is_dir():
        return []

    candidates = []
    for entry in study_dir.iterdir():
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS:
            candidates.append(entry)
    candidates.sort()

    if max_images_per_study > 0:
        return candidates[:max_images_per_study]
    return candidates
def _append_manifest_row(
    manifest,
    manifest_ids: set,
    study_id: str,
    image_paths: List[Path],
    pred_path: Path,
) -> None:
    """Append one JSONL row for *study_id* unless the manifest already has it."""
    if study_id in manifest_ids:
        return
    row = {
        "study_id": study_id,
        "image_count": len(image_paths),
        "image_paths": [str(p) for p in image_paths],
        "prediction_path": str(pred_path),
    }
    manifest.write(json.dumps(row, ensure_ascii=False) + "\n")
    # Flush per row so an abruptly killed run does not lose manifest entries
    # for reports that were already written to disk.
    manifest.flush()
    manifest_ids.add(study_id)


def main() -> None:
    """Generate reports for every validation study and record them.

    Steps:
      1. Check that VAL_ROOT, CHECKPOINT_PATH and SYSTEM_PROMPT_PATH exist.
      2. Seed torch, load the model and processor, switch to inference mode.
      3. For each study id (sorted), skip studies without images, skip studies
         whose report file is already complete (back-filling a missing
         manifest row), and otherwise generate and write the report.
      4. Print a summary of processed / skipped / failed counts.

    The run is resumable: existing complete reports are not regenerated and
    the manifest is opened in append mode. A generation error for one study
    is logged and counted, and the loop moves on.

    Raises:
        FileNotFoundError: If a required input path is missing.
    """
    if not VAL_ROOT.is_dir():
        raise FileNotFoundError(f"val_root not found: {VAL_ROOT}")
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"checkpoint not found: {CHECKPOINT_PATH}")
    if not SYSTEM_PROMPT_PATH.exists():
        raise FileNotFoundError(f"system prompt file not found: {SYSTEM_PROMPT_PATH}")

    system_prompt = SYSTEM_PROMPT_PATH.read_text(encoding="utf-8").strip()
    torch.manual_seed(SEED)

    print(f"Loading model from {CHECKPOINT_PATH}")
    model, _ = FastVisionModel.from_pretrained(
        model_name=str(CHECKPOINT_PATH),
        load_in_4bit=LOAD_IN_4BIT,
        load_in_8bit=LOAD_IN_8BIT,
    )
    processor = load_processor(CHECKPOINT_PATH, BASE_MODEL_NAME)
    FastVisionModel.for_inference(model)

    out_reports = OUTPUT_DIR / "reports"
    out_reports.mkdir(parents=True, exist_ok=True)
    manifest_path = OUTPUT_DIR / "predictions.jsonl"
    manifest_ids = load_existing_manifest_ids(manifest_path)
    study_ids = get_study_ids(VAL_ROOT)
    total = len(study_ids)

    written = 0
    skipped_existing = 0
    failed = 0
    with manifest_path.open("a", encoding="utf-8") as manifest:
        for idx, study_id in enumerate(study_ids, start=1):
            progress = f"[{idx}/{total}] {study_id}"

            image_paths = get_study_images(VAL_ROOT, study_id, MAX_IMAGES_PER_STUDY)
            if not image_paths:
                print(f"{progress}: no images, skipped")
                continue

            pred_path = out_reports / f"{study_id}.txt"
            if is_completed_report(pred_path):
                skipped_existing += 1
                # The report exists but the manifest may have missed it
                # (e.g. an interrupted earlier run); back-fill the row.
                _append_manifest_row(manifest, manifest_ids, study_id, image_paths, pred_path)
                print(f"{progress}: already done, skipping")
                continue

            try:
                pred = generate_report(
                    model=model,
                    processor=processor,
                    image_paths=image_paths,
                    prompt_text=system_prompt,
                    max_new_tokens=MAX_NEW_TOKENS,
                    temperature=TEMPERATURE,
                )
            except Exception as error:
                # One bad study must not stop the run. Out-of-memory errors
                # get their own label; everything else is a generic failure.
                failed += 1
                reason = "OOM" if isinstance(error, torch.OutOfMemoryError) else "failed"
                print(f"{progress}: {reason}, skipped ({error})")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue

            # A "<no readable images>" placeholder is written as-is;
            # is_completed_report treats it as incomplete, so the study is
            # retried on the next run.
            pred_path.write_text(pred + "\n", encoding="utf-8")
            _append_manifest_row(manifest, manifest_ids, study_id, image_paths, pred_path)
            written += 1
            print(f"{progress}: done ({len(image_paths)} image(s))")

    print("=" * 60)
    print(f"Newly processed: {written}")
    print(f"Already completed (skipped): {skipped_existing}")
    print(f"Failed this run: {failed}")
    print(f"Predictions dir: {out_reports}")
    print(f"Manifest: {manifest_path}")


if __name__ == "__main__":
    main()