"""Generate radiology reports for a validation split with a LoRA-tuned Qwen3-VL model.

For every study listed as ``VAL_ROOT/files/<study_id>.txt`` that has images under
``VAL_ROOT/images/<study_id>/``, the script runs the fine-tuned vision model, keeps
the Findings/Impression sections of the generated text, and writes one report per
study to ``OUTPUT_DIR/reports/<study_id>.txt``. A JSONL manifest
(``OUTPUT_DIR/predictions.jsonl``) gets one row per study. Re-running the script
resumes: studies that already have a non-empty, non-placeholder report are skipped.
"""

import os

# Restrict the process to GPU 6; this must happen before torch/unsloth touch CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Unsloth is imported before transformers, as Unsloth recommends, so its patches apply.
from unsloth import FastVisionModel

import json
import re
from pathlib import Path
from typing import Dict, List, Optional, Set, TextIO, Tuple

import torch
from PIL import Image
from transformers import AutoProcessor

CHECKPOINT_PATH = Path("outputs/mimic_qwen3vl_lora_8bit_5/checkpoint-17454")
BASE_MODEL_NAME = "unsloth/Qwen3-VL-8B-Thinking"
SYSTEM_PROMPT_PATH = Path(__file__).with_name("new_system_prompt.txt")
VAL_ROOT = Path("dataset/val")
OUTPUT_DIR = Path("dataset/val_outputs")
MAX_NEW_TOKENS = 4096
TEMPERATURE = 0.0  # <= 0 selects greedy decoding (see generate_report)
SEED = 3407
MAX_IMAGES_PER_STUDY = 0  # <= 0 means "use every image in the study folder"
LOAD_IN_4BIT = False
LOAD_IN_8BIT = True
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}

# A section header: start of line, optional indent, a name of at least 4 characters
# (letter first, then letters/spaces/parentheses/slashes), a colon, optional blanks.
_SECTION_HEADER_RE = re.compile(r"^[ \t]*([A-Za-z][A-Za-z ()/]{3,}):[ \t]*", re.MULTILINE)
# A horizontal-rule line made of 5+ underscores, dashes or equals signs.
_SEPARATOR_LINE_RE = re.compile(r"^[_\-=]{5,}$")
# Qwen3 "Thinking" checkpoints close their reasoning segment with this tag.
_REASONING_END_TAG = "</think>"


def clean_text(text: str) -> str:
    """Strip every line of ``text`` and drop the blank ones."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return "\n".join(lines).strip()


def extract_findings_impression(text: str) -> Optional[str]:
    """Return only the Findings and Impression sections of a report.

    A section runs from its header to the next recognised header (or the end of
    the text). Headers whose name contains ``PROVISIONAL`` are not treated as
    boundaries. For each section the first non-empty occurrence wins; lines that
    are pure separators (5+ ``_``, ``-`` or ``=``) are dropped.

    Returns:
        ``"Findings:\\n...\\n\\nImpression:\\n..."`` (either part may be absent), or
        ``None`` when neither section has content.
    """
    text = text.replace("\r\n", "\n").replace("\r", "\n")

    # (upper-cased header name, header start offset, content start offset)
    bounds: List[Tuple[str, int, int]] = []
    for match in _SECTION_HEADER_RE.finditer(text):
        name = match.group(1).strip().upper()
        if "PROVISIONAL" in name:
            continue
        bounds.append((name, match.start(), match.end()))

    wanted = {"FINDINGS": "Findings", "IMPRESSION": "Impression"}
    found: Dict[str, str] = {}
    for i, (name, _header_start, content_start) in enumerate(bounds):
        key = wanted.get(name)
        if key is None or key in found:
            continue
        content_end = bounds[i + 1][1] if i + 1 < len(bounds) else len(text)
        lines = [
            line.strip()
            for line in text[content_start:content_end].splitlines()
            if line.strip() and not _SEPARATOR_LINE_RE.match(line.strip())
        ]
        content = "\n".join(lines)
        if content:
            found[key] = content

    if not found:
        return None
    return "\n\n".join(
        f"{heading}:\n{found[heading]}"
        for heading in ("Findings", "Impression")
        if heading in found
    )


def strip_reasoning(text: str) -> str:
    """Drop everything up to and including the last ``</think>`` tag, if present.

    Without this, section extraction could pick up a "Findings:" line written
    inside the model's reasoning instead of the final answer. Text without the
    tag is returned unchanged.
    """
    # rpartition yields ("", "", text) when the tag is absent.
    return text.rpartition(_REASONING_END_TAG)[2]


def clean_report_text(text: str) -> str:
    """Normalise raw model output into the report text that gets saved.

    Prefers the extracted Findings/Impression sections; falls back to the whole
    answer with blank lines removed when no such sections are found.
    """
    answer = strip_reasoning(text)
    extracted = extract_findings_impression(answer)
    if extracted:
        return extracted.strip()
    return clean_text(answer)


def load_images(image_paths: List[Path]) -> List[Image.Image]:
    """Open each path as an RGB image, skipping (with a warning) any that fail."""
    images: List[Image.Image] = []
    for image_path in image_paths:
        try:
            with Image.open(image_path) as img:
                # convert() returns a new, fully loaded image, so it outlives the file handle.
                images.append(img.convert("RGB"))
        except Exception as error:
            print(f"[warn] Skipping unreadable image {image_path}: {error}")
    return images


def build_messages(images: List[Image.Image], prompt_text: str) -> List[Dict]:
    """Build a single user turn: all image blocks first, then one text block."""
    # Keep the message format consistent with base-inference.py:
    # one user turn with image blocks first and one text block last.
    user_content: List[Dict] = [{"type": "image", "image": image} for image in images]
    user_content.append({"type": "text", "text": prompt_text})
    return [{"role": "user", "content": user_content}]


def _release_cuda_cache() -> None:
    """Return cached, unused CUDA memory to the driver when a GPU is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_report(
    model,
    processor,
    image_paths: List[Path],
    prompt_text: str,
    max_new_tokens: int,
    temperature: float,
) -> str:
    """Generate and clean one report for a study.

    Args:
        model: Vision-language model exposing ``generate`` and ``parameters``.
        processor: Multimodal processor providing ``apply_chat_template``/``decode``.
        image_paths: Images of the study; unreadable ones are skipped.
        prompt_text: Instruction text appended after the images.
        max_new_tokens: Generation length cap.
        temperature: ``> 0`` samples at that temperature; otherwise greedy decoding.

    Returns:
        The cleaned report, or ``""`` when none of the images could be opened.
    """
    images = load_images(image_paths)
    if not images:
        return ""

    messages = build_messages(images, prompt_text)
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    generate_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    else:
        # Be explicit: otherwise generate() inherits do_sample from the checkpoint's
        # generation_config, which may enable sampling and break determinism.
        generate_kwargs["do_sample"] = False

    with torch.inference_mode():
        outputs = model.generate(**inputs, **generate_kwargs)

    # generate() returns prompt + continuation, so cut at the full prompt width
    # (an attention-mask sum would be too short if the prompt were padded).
    prompt_len = int(inputs["input_ids"].shape[-1])
    text = processor.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # Drop tensor references first so empty_cache can actually release their memory.
    del inputs, outputs
    _release_cuda_cache()
    return clean_report_text(text)


def load_processor(checkpoint_path: Path, base_model: str):
    """Load the multimodal processor; ``checkpoint_path`` is intentionally unused."""
    # Match working inference script: always load the full multimodal processor
    # from the base model, not from LoRA checkpoint folders.
    return AutoProcessor.from_pretrained(base_model, trust_remote_code=True)


def is_completed_report(report_path: Path) -> bool:
    """Return True if ``report_path`` holds a usable report from a previous run."""
    try:
        text = report_path.read_text(encoding="utf-8", errors="ignore").strip()
    except OSError:  # missing or unreadable file
        return False
    if not text:
        return False
    # Treat prior placeholder failures as incomplete so they get regenerated.
    # NOTE(review): the original placeholder test was garbled; placeholders are
    # assumed to be "<...>" marker lines — confirm against what earlier runs wrote.
    if text.startswith("<") and text.endswith(">"):
        return False
    return True


def load_existing_manifest_ids(manifest_path: Path) -> Set[str]:
    """Collect ``study_id`` values already recorded in the JSONL manifest.

    Blank lines, malformed JSON and non-object rows are ignored; an unreadable or
    missing manifest yields whatever was collected so far (possibly nothing).
    """
    existing_ids: Set[str] = set()
    if not manifest_path.exists():
        return existing_ids
    try:
        with manifest_path.open("r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(row, dict):
                    continue
                study_id = row.get("study_id")
                if study_id:
                    existing_ids.add(str(study_id))
    except OSError:
        pass
    return existing_ids


def get_study_ids(val_root: Path) -> List[str]:
    """Return the sorted stems of ``val_root/files/*.txt``.

    Raises:
        FileNotFoundError: if ``val_root/files`` is not a directory.
    """
    files_dir = val_root / "files"
    if not files_dir.is_dir():
        raise FileNotFoundError(f"Missing directory: {files_dir}")
    return sorted(path.stem for path in files_dir.glob("*.txt"))


def get_study_images(val_root: Path, study_id: str, max_images_per_study: int) -> List[Path]:
    """Return the study's image files (sorted), capped when ``max_images_per_study > 0``."""
    study_dir = val_root / "images" / study_id
    if not study_dir.is_dir():
        return []
    image_paths = sorted(
        p for p in study_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    if max_images_per_study > 0:
        image_paths = image_paths[:max_images_per_study]
    return image_paths


def _append_manifest_row(
    manifest: TextIO,
    manifest_ids: Set[str],
    study_id: str,
    image_paths: List[Path],
    pred_path: Path,
) -> None:
    """Append one manifest row for ``study_id`` unless it is already recorded."""
    if study_id in manifest_ids:
        return
    row = {
        "study_id": study_id,
        "image_count": len(image_paths),
        "image_paths": [str(p) for p in image_paths],
        "prediction_path": str(pred_path),
    }
    manifest.write(json.dumps(row, ensure_ascii=False) + "\n")
    # Flush per row so an interrupted run keeps every row written so far.
    manifest.flush()
    manifest_ids.add(study_id)


def main() -> None:
    """Run inference over the validation split, resuming from earlier outputs."""
    if not VAL_ROOT.is_dir():
        raise FileNotFoundError(f"val_root not found: {VAL_ROOT}")
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"checkpoint not found: {CHECKPOINT_PATH}")
    if not SYSTEM_PROMPT_PATH.exists():
        raise FileNotFoundError(f"system prompt file not found: {SYSTEM_PROMPT_PATH}")

    system_prompt = SYSTEM_PROMPT_PATH.read_text(encoding="utf-8").strip()
    torch.manual_seed(SEED)

    print(f"Loading model from {CHECKPOINT_PATH}")
    model, _ = FastVisionModel.from_pretrained(
        model_name=str(CHECKPOINT_PATH),
        load_in_4bit=LOAD_IN_4BIT,
        load_in_8bit=LOAD_IN_8BIT,
    )
    processor = load_processor(CHECKPOINT_PATH, BASE_MODEL_NAME)
    FastVisionModel.for_inference(model)

    out_reports = OUTPUT_DIR / "reports"
    out_reports.mkdir(parents=True, exist_ok=True)
    manifest_path = OUTPUT_DIR / "predictions.jsonl"
    manifest_ids = load_existing_manifest_ids(manifest_path)
    study_ids = get_study_ids(VAL_ROOT)
    total = len(study_ids)

    written = 0
    skipped_existing = 0
    failed = 0
    with manifest_path.open("a", encoding="utf-8") as manifest:
        for idx, study_id in enumerate(study_ids, start=1):
            tag = f"[{idx}/{total}] {study_id}"
            image_paths = get_study_images(VAL_ROOT, study_id, MAX_IMAGES_PER_STUDY)
            if not image_paths:
                print(f"{tag}: no images, skipped")
                continue

            pred_path = out_reports / f"{study_id}.txt"
            if is_completed_report(pred_path):
                skipped_existing += 1
                # Backfill the manifest if a previous run died before recording it.
                _append_manifest_row(manifest, manifest_ids, study_id, image_paths, pred_path)
                print(f"{tag}: already done, skipping")
                continue

            try:
                pred = generate_report(
                    model=model,
                    processor=processor,
                    image_paths=image_paths,
                    prompt_text=system_prompt,
                    max_new_tokens=MAX_NEW_TOKENS,
                    temperature=TEMPERATURE,
                )
            except torch.cuda.OutOfMemoryError as error:
                failed += 1
                print(f"{tag}: OOM, skipped ({error})")
                _release_cuda_cache()
                continue
            except Exception as error:  # keep the batch going; the study is retried next run
                failed += 1
                print(f"{tag}: failed, skipped ({error})")
                _release_cuda_cache()
                continue

            if not pred:
                # No readable image or empty output: write nothing so the next run
                # retries it (an empty file would be treated as incomplete anyway).
                failed += 1
                print(f"{tag}: empty prediction, skipped")
                continue

            pred_path.write_text(pred + "\n", encoding="utf-8")
            _append_manifest_row(manifest, manifest_ids, study_id, image_paths, pred_path)
            written += 1
            print(f"{tag}: done ({len(image_paths)} image(s))")

    print("=" * 60)
    print(f"Newly processed: {written}")
    print(f"Already completed (skipped): {skipped_existing}")
    print(f"Failed this run: {failed}")
    print(f"Predictions dir: {out_reports}")
    print(f"Manifest: {manifest_path}")


if __name__ == "__main__":
    main()