| import os |
# Restrict this process to GPU 6. This has to run before torch/unsloth are
# imported further down, because CUDA reads CUDA_VISIBLE_DEVICES when it
# initializes.
# NOTE(review): CUDA_AVAILABLE_DEVICES is not a variable CUDA recognizes; it is
# kept only so the process environment stays exactly as before.
os.environ.update(
    {
        "CUDA_AVAILABLE_DEVICES": "6",
        "CUDA_VISIBLE_DEVICES": "6",
    }
)
|
|
| from unsloth import FastVisionModel |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict, List, Optional |
|
|
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor |
|
|
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Directory handed to FastVisionModel.from_pretrained in main().
CHECKPOINT_PATH = Path("outputs") / "mimic_qwen3vl_lora_8bit_5" / "checkpoint-17454"
# Hub id that load_processor() fetches the processor from.
BASE_MODEL_NAME = "unsloth/Qwen3-VL-8B-Thinking"
# Prompt text file expected to sit next to this script.
SYSTEM_PROMPT_PATH = Path(__file__).parent / "new_system_prompt.txt"
# Validation data root; main() reads study ids and images beneath it.
VAL_ROOT = Path("dataset", "val")
# Destination for per-study reports and the predictions.jsonl manifest.
OUTPUT_DIR = Path("dataset", "val_outputs")

# Decoding settings; generate_report() only samples when TEMPERATURE > 0.
MAX_NEW_TOKENS = 4096
TEMPERATURE = 0.0
SEED = 3407

# Cap on images per study; a value <= 0 keeps every image (see get_study_images()).
MAX_IMAGES_PER_STUDY = 0

# Quantization flags forwarded to FastVisionModel.from_pretrained.
LOAD_IN_4BIT = False
LOAD_IN_8BIT = True

# Lower-case file suffixes accepted as study images (compared against suffix.lower()).
IMAGE_EXTENSIONS = {".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp"}
|
|
|
|
def clean_text(text: str) -> str:
    """Normalize free text: strip every line, drop blank lines, rejoin with newlines."""
    kept = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if stripped:
            kept.append(stripped)
    return "\n".join(kept).strip()
|
|
|
|
def extract_findings_impression(text: str) -> Optional[str]:
    """Pull the FINDINGS and IMPRESSION sections out of a free-text report.

    A section header is a line that begins (after optional spaces/tabs) with a
    letter followed by at least three more letters, spaces, parentheses or
    slashes, and then a colon. Header names are compared case-insensitively.
    Headers whose name contains "PROVISIONAL" are ignored entirely, so they
    neither start nor end a section.

    A section body runs from the end of its header to the start of the next
    recognised header, or to the end of the text. Its lines are stripped, and
    blank lines and separator rules (5+ "_", "-" or "=") are dropped. For each
    of FINDINGS and IMPRESSION, the first occurrence with a non-empty body is
    kept.

    Returns:
        "Findings:" and/or "Impression:", each followed by a newline and its
        body, joined by a blank line with Findings first. Returns None when
        neither section has content.
    """
    import re

    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    header_pattern = re.compile(r"^[ \t]*([A-Za-z][A-Za-z ()/]{3,}):[ \t]*", re.MULTILINE)
    rule_pattern = re.compile(r"^[_\-=]{5,}$")
    wanted = {"FINDINGS": "Findings", "IMPRESSION": "Impression"}

    # NOTE(review): any "Label:" line at line start (e.g. "Heart size: normal")
    # also counts as a header and therefore ends the current section, so
    # findings written with inline sub-labels get truncated.
    headers = []  # (UPPERCASE name, header start offset, body start offset)
    for match in header_pattern.finditer(normalized):
        label = match.group(1).strip().upper()
        if "PROVISIONAL" not in label:
            headers.append((label, match.start(), match.end()))

    sections = {}
    for position, (label, _header_start, body_start) in enumerate(headers):
        heading = wanted.get(label)
        if heading is None or heading in sections:
            continue
        is_last = position == len(headers) - 1
        body_end = len(normalized) if is_last else headers[position + 1][1]
        kept_lines = []
        for raw_line in normalized[body_start:body_end].splitlines():
            stripped = raw_line.strip()
            if stripped and not rule_pattern.match(stripped):
                kept_lines.append(stripped)
        if kept_lines:
            sections[heading] = "\n".join(kept_lines)

    if not sections:
        return None
    return "\n\n".join(
        f"{heading}:\n{sections[heading]}"
        for heading in ("Findings", "Impression")
        if heading in sections
    )
|
|
|
|
def clean_report_text(text: str) -> str:
    """Reduce raw model output to its Findings/Impression sections.

    If the text contains a "</think>" marker, only the text after its last
    occurrence is kept. This stops reasoning emitted by a "Thinking" checkpoint
    (BASE_MODEL_NAME is a Qwen3-VL Thinking model) from being parsed as the
    report, which would otherwise let a "Findings:" inside the reasoning win.

    The remaining text goes through extract_findings_impression(). If no section
    is found, the whole remainder is normalized with clean_text() instead.
    """
    # NOTE(review): assumes the reasoning block is closed by a literal "</think>"
    # that survives decode(skip_special_tokens=True); confirm against the
    # tokenizer config. Text without the marker (e.g. generation truncated while
    # still reasoning) is processed unchanged, exactly as before.
    _reasoning, marker, answer = text.rpartition("</think>")
    if marker:
        text = answer
    extracted = extract_findings_impression(text)
    if extracted:
        return extracted.strip()
    return clean_text(text)
|
|
|
|
def load_images(image_paths: List[Path]) -> List[Image.Image]:
    """Open each path and convert it to an RGB image, preserving input order.

    Unreadable files are skipped with a "[warn]" message rather than aborting
    the whole study. Any exception raised while opening or converting counts
    as unreadable.
    """
    loaded: List[Image.Image] = []
    for path in image_paths:
        try:
            with Image.open(path) as handle:
                # convert() materializes the pixel data, so the copy stays
                # usable after the file handle is closed.
                loaded.append(handle.convert("RGB"))
        except Exception as error:
            print(f"[warn] Skipping unreadable image {path}: {error}")
    return loaded
|
|
|
|
def build_messages(images: List[Image.Image], prompt_text: str) -> List[Dict]:
    """Build a single-turn chat for processor.apply_chat_template.

    The user turn contains one {"type": "image"} entry per image, in order,
    followed by a single {"type": "text"} entry holding prompt_text.
    """
    image_parts = [{"type": "image", "image": picture} for picture in images]
    text_part = {"type": "text", "text": prompt_text}
    return [{"role": "user", "content": [*image_parts, text_part]}]
|
|
|
|
def generate_report(
    model,
    processor,
    image_paths: List[Path],
    prompt_text: str,
    max_new_tokens: int,
    temperature: float,
) -> str:
    """Generate a report for one study's images and return the cleaned text.

    Args:
        model: Model loaded in main() via FastVisionModel; must expose
            parameters() and generate().
        processor: Processor whose apply_chat_template tokenizes images + prompt.
        image_paths: Image files for one study; unreadable files are skipped.
        prompt_text: Instruction text sent in the user turn after the images.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Values <= 0 select greedy decoding; values > 0 enable
            sampling at that temperature.

    Returns:
        The decoded continuation after clean_report_text(), or the placeholder
        "<no readable images>" when none of the images could be loaded.
    """
    images = load_images(image_paths)
    if not images:
        return "<no readable images>"

    messages = build_messages(images, prompt_text)
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    # Always pass do_sample explicitly. If it is omitted, the checkpoint's
    # generation_config decides, and chat checkpoints may default to sampling
    # (Qwen configs typically set do_sample=True; confirm for this one). That
    # would make TEMPERATURE=0 runs non-deterministic instead of greedy.
    generate_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    else:
        generate_kwargs["do_sample"] = False

    with torch.inference_mode():
        outputs = model.generate(**inputs, **generate_kwargs)

    # generate() returns prompt + continuation. The prompt occupies exactly
    # input_ids.shape[-1] positions (including any padding), so slice there;
    # an attention-mask sum would undercount whenever padding is present.
    prompt_len = int(inputs["input_ids"].shape[-1])
    text = processor.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return clean_report_text(text)
|
|
|
|
def load_processor(checkpoint_path: Path, base_model: str):
    """Load the AutoProcessor used for inference.

    The processor always comes from base_model. checkpoint_path is accepted for
    interface compatibility but is not consulted.
    """
    source = base_model  # checkpoint_path intentionally unused
    processor = AutoProcessor.from_pretrained(source, trust_remote_code=True)
    return processor
|
|
|
|
def is_completed_report(report_path: Path) -> bool:
    """Tell whether report_path already holds a usable prediction.

    A report counts as complete when all of the following hold:
    - the file exists and can be read (undecodable bytes are ignored);
    - it is non-blank after stripping;
    - it is not a failure placeholder, i.e. it neither starts with
      "<tokenization failed:" nor equals "<no readable images>".
    """
    if not report_path.exists():
        return False
    try:
        body = report_path.read_text(encoding="utf-8", errors="ignore").strip()
    except OSError:
        return False
    is_placeholder = body.startswith("<tokenization failed:") or body == "<no readable images>"
    return bool(body) and not is_placeholder
|
|
|
|
def load_existing_manifest_ids(manifest_path: Path) -> set:
    """Return the study ids already recorded in a JSONL manifest.

    Each non-blank line is parsed as JSON. Lines that are not valid JSON, are
    not JSON objects, or lack a truthy "study_id" are skipped. Ids are returned
    as strings.

    A missing file yields an empty set. This is a best-effort reader: an
    unreadable file prints a warning and yields whatever was collected, and
    undecodable bytes are replaced rather than aborting the run.
    """
    existing_ids = set()
    if not manifest_path.exists():
        return existing_ids
    try:
        with manifest_path.open("r", encoding="utf-8", errors="replace") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    # e.g. a partially written last line from an interrupted run
                    continue
                # Valid JSON that is not an object (list, string, number) has no
                # .get(); the original code crashed on such lines.
                if not isinstance(row, dict):
                    continue
                study_id = row.get("study_id")
                if study_id:
                    existing_ids.add(str(study_id))
    except OSError as error:
        print(f"[warn] Could not read manifest {manifest_path}: {error}")
    return existing_ids
|
|
|
|
def get_study_ids(val_root: Path) -> List[str]:
    """List study ids as the sorted stems of val_root/files/*.txt.

    Raises:
        FileNotFoundError: if val_root/files is not a directory.
    """
    reports_dir = val_root / "files"
    if not reports_dir.is_dir():
        raise FileNotFoundError(f"Missing directory: {reports_dir}")
    ids = [report.stem for report in reports_dir.glob("*.txt")]
    ids.sort()
    return ids
|
|
|
|
def get_study_images(val_root: Path, study_id: str, max_images_per_study: int) -> List[Path]:
    """Return the sorted image files under val_root/images/<study_id>.

    Only regular files whose lower-cased suffix is in IMAGE_EXTENSIONS are kept.
    When max_images_per_study > 0, the list is truncated to that many entries;
    otherwise every image is returned. A missing study directory yields [].
    """
    folder = val_root / "images" / study_id
    if not folder.is_dir():
        return []
    candidates = [
        entry
        for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS
    ]
    candidates.sort()
    return candidates[:max_images_per_study] if max_images_per_study > 0 else candidates
|
|
|
|
def _manifest_row(study_id: str, image_paths: List[Path], pred_path: Path) -> Dict:
    """Build the predictions.jsonl record describing one study's prediction."""
    return {
        "study_id": study_id,
        "image_count": len(image_paths),
        "image_paths": [str(p) for p in image_paths],
        "prediction_path": str(pred_path),
    }


def _record_in_manifest(manifest, manifest_ids: set, study_id: str, image_paths: List[Path], pred_path: Path) -> None:
    """Append a manifest row for study_id to the open manifest unless one is already recorded."""
    if study_id in manifest_ids:
        return
    row = _manifest_row(study_id, image_paths, pred_path)
    manifest.write(json.dumps(row, ensure_ascii=False) + "\n")
    # Flush per row so the manifest reflects progress even if the run is killed.
    manifest.flush()
    manifest_ids.add(study_id)


def main() -> None:
    """Run inference over every validation study, writing reports and a manifest.

    Each prediction is written to OUTPUT_DIR/reports/<study_id>.txt. A JSONL row
    pointing at it is appended to OUTPUT_DIR/predictions.jsonl.

    The run is resumable:
    - reports that already pass is_completed_report() are skipped, and are
      added to the manifest if missing;
    - studies whose generation raises an exception are counted as failed;
    - studies that produce an incomplete report (e.g. "<no readable images>"
      or empty output) are also counted as failed.
    Failed studies are retried on the next run.

    Raises:
        FileNotFoundError: if VAL_ROOT, CHECKPOINT_PATH or SYSTEM_PROMPT_PATH
            does not exist.
    """
    if not VAL_ROOT.is_dir():
        raise FileNotFoundError(f"val_root not found: {VAL_ROOT}")
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"checkpoint not found: {CHECKPOINT_PATH}")
    if not SYSTEM_PROMPT_PATH.exists():
        raise FileNotFoundError(f"system prompt file not found: {SYSTEM_PROMPT_PATH}")

    system_prompt = SYSTEM_PROMPT_PATH.read_text(encoding="utf-8").strip()

    torch.manual_seed(SEED)

    print(f"Loading model from {CHECKPOINT_PATH}")
    # The second value returned by from_pretrained is discarded; load_processor()
    # supplies the processor used for inference.
    model, _ = FastVisionModel.from_pretrained(
        model_name=str(CHECKPOINT_PATH),
        load_in_4bit=LOAD_IN_4BIT,
        load_in_8bit=LOAD_IN_8BIT,
    )
    processor = load_processor(CHECKPOINT_PATH, BASE_MODEL_NAME)
    FastVisionModel.for_inference(model)

    out_reports = OUTPUT_DIR / "reports"
    out_reports.mkdir(parents=True, exist_ok=True)
    manifest_path = OUTPUT_DIR / "predictions.jsonl"
    manifest_ids = load_existing_manifest_ids(manifest_path)

    study_ids = get_study_ids(VAL_ROOT)
    total = len(study_ids)

    written = 0
    skipped_existing = 0
    failed = 0
    # Append mode keeps rows from earlier (possibly interrupted) runs;
    # manifest_ids prevents duplicate rows for the same study.
    with manifest_path.open("a", encoding="utf-8") as manifest:
        for idx, study_id in enumerate(study_ids, start=1):
            progress = f"[{idx}/{total}] {study_id}"
            image_paths = get_study_images(VAL_ROOT, study_id, MAX_IMAGES_PER_STUDY)
            if not image_paths:
                print(f"{progress}: no images, skipped")
                continue

            pred_path = out_reports / f"{study_id}.txt"
            if is_completed_report(pred_path):
                skipped_existing += 1
                _record_in_manifest(manifest, manifest_ids, study_id, image_paths, pred_path)
                print(f"{progress}: already done, skipping")
                continue

            try:
                pred = generate_report(
                    model=model,
                    processor=processor,
                    image_paths=image_paths,
                    prompt_text=system_prompt,
                    max_new_tokens=MAX_NEW_TOKENS,
                    temperature=TEMPERATURE,
                )
            except Exception as error:
                # One bad study (including CUDA OOM) must not abort the batch.
                failed += 1
                reason = "OOM" if isinstance(error, torch.cuda.OutOfMemoryError) else "failed"
                print(f"{progress}: {reason}, skipped ({error})")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue

            pred_path.write_text(pred + "\n", encoding="utf-8")

            # Apply the same completeness rule as the resume check above.
            # Placeholders such as "<no readable images>" and empty output are
            # reported as failures and kept out of the manifest; the next run
            # retries them.
            if not is_completed_report(pred_path):
                failed += 1
                print(f"{progress}: no usable output, will retry on next run")
                continue

            _record_in_manifest(manifest, manifest_ids, study_id, image_paths, pred_path)
            written += 1
            print(f"{progress}: done ({len(image_paths)} image(s))")

    print("=" * 60)
    print(f"Newly processed: {written}")
    print(f"Already completed (skipped): {skipped_existing}")
    print(f"Failed this run: {failed}")
    print(f"Predictions dir: {out_reports}")
    print(f"Manifest: {manifest_path}")
|
|
|
|
# Entry point: run batch inference only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|