# Unsloth must be imported before trl/transformers so its runtime patches take effect.
from unsloth import FastVisionModel
from unsloth.trainer import UnslothVisionDataCollator
|
|
| import argparse |
| import json |
| import random |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| from datasets import Dataset |
| from PIL import Image, ImageFile, UnidentifiedImageError |
| from trl import SFTConfig, SFTTrainer |
|
|
|
|
# Allow PIL to decode partially written/truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
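
# Example invocation (a sketch; the script filename and data paths are placeholders
# for your local setup):
#
#   python finetune_qwen3vl_mimic.py \
#       --dataset_root dataset \
#       --reports_dir files \
#       --batch_size 4 --grad_accum 2 --num_train_epochs 3.0 --use_wandb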
|
|
|
|
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert radiology report generator. "
    "Given one chest X-ray image, write a clinically coherent report in the style of a radiologist. "
    "Do not hallucinate findings that are not supported by the image. "
    "Also give reasoning for your findings and highlight the key areas or features in the image that support them."
)
|
|
| DEFAULT_INSTRUCTION = ( |
| "Analyze this chest X-ray image and generate the corresponding radiology report text." |
| ) |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for data location, model, LoRA, and training knobs."""
    parser = argparse.ArgumentParser(
        description="Finetune Qwen3-VL 8B with Unsloth on MIMIC-style image/report data."
    )
|
|
| parser.add_argument("--dataset_root", type=str, default="dataset") |
| parser.add_argument("--reports_dir", type=str, default="files") |
| parser.add_argument("--images_glob", type=str, default="images_part_*") |
| parser.add_argument("--model_name", type=str, default="unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit") |
| parser.add_argument("--output_dir", type=str, default="outputs/mimic_qwen3vl_lora") |
|
|
| parser.add_argument("--seed", type=int, default=3407) |
| parser.add_argument("--val_ratio", type=float, default=0.05) |
| parser.add_argument("--max_images_per_study", type=int, default=0, help="0 = use all images per study") |
| parser.add_argument("--max_train_samples", type=int, default=0, help="0 = use all train samples") |
| parser.add_argument("--max_val_samples", type=int, default=0, help="0 = use all val samples") |
| parser.add_argument("--min_report_chars", type=int, default=40) |
| parser.add_argument( |
| "--image_validity_cache", |
| type=str, |
| default="", |
| help="Path to JSON cache for image readability checks. Default: <dataset_root>/.image_validity_cache.json", |
| ) |
| parser.add_argument( |
| "--skip_image_verification", |
| action="store_true", |
| help="Skip pre-verifying image files. Faster startup, but corrupted images may fail at runtime.", |
| ) |
|
|
| parser.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION) |
| parser.add_argument( |
| "--system_prompt", |
| type=str, |
| default=DEFAULT_SYSTEM_PROMPT, |
| help="Set to empty string to disable system prompt.", |
| ) |
|
|
| parser.add_argument("--load_in_4bit", action="store_true", default=True) |
| parser.add_argument("--no_4bit", action="store_true", help="Disable 4bit loading") |
|
|
| parser.add_argument("--lora_r", type=int, default=32) |
| parser.add_argument("--lora_alpha", type=int, default=64) |
| parser.add_argument("--lora_dropout", type=float, default=0.01) |
|
|
| parser.add_argument("--batch_size", type=int, default=8) |
| parser.add_argument("--grad_accum", type=int, default=1) |
| parser.add_argument("--learning_rate", type=float, default=2e-5) |
| parser.add_argument("--warmup_steps", type=int, default=100) |
| parser.add_argument("--num_train_epochs", type=float, default=5.0) |
| parser.add_argument("--max_length", type=int, default=2048) |
| parser.add_argument("--logging_steps", type=int, default=10) |
| parser.add_argument("--save_steps", type=int, default=200) |
| parser.add_argument("--dataloader_num_workers", type=int, default=0) |
| parser.add_argument("--use_wandb", action="store_true", help="Enable Weights & Biases logging") |
| parser.add_argument("--wandb_project", type=str, default="qwen3vl-mimic-finetune") |
| parser.add_argument("--wandb_run_name", type=str, default="") |
| parser.add_argument("--wandb_entity", type=str, default="") |
|
|
| return parser.parse_args() |
|
|
|
|
def clean_report_text(text: str) -> str:
    """Strip per-line whitespace and drop blank lines from a raw report."""
    lines = [line.strip() for line in text.splitlines()]
    non_empty = [line for line in lines if line]
    return "\n".join(non_empty).strip()
|
|
|
|
def get_study_image_paths(dataset_root: Path, images_glob: str, study_id: str) -> List[Path]:
    """Collect every image file for a study across all images_part_* shard directories."""
    image_paths: List[Path] = []
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
|
| for images_part in sorted(dataset_root.glob(images_glob)): |
| if not images_part.is_dir(): |
| continue |
| study_dir = images_part / study_id |
| if not study_dir.exists() or not study_dir.is_dir(): |
| continue |
| for image_path in sorted(study_dir.iterdir()): |
| if image_path.suffix.lower() in image_extensions: |
| image_paths.append(image_path) |
|
|
| return image_paths |
|
|
|
|
| def build_samples( |
| dataset_root: Path, |
| reports_dir_name: str, |
| images_glob: str, |
| min_report_chars: int, |
| max_images_per_study: int, |
) -> List[Dict[str, str]]:
    """Pair each report .txt with all of its study's images, one sample per (image, report)."""
    reports_dir = dataset_root / reports_dir_name
    if not reports_dir.exists():
        raise FileNotFoundError(f"Reports folder not found: {reports_dir}")
|
|
| report_files = sorted(reports_dir.glob("*.txt")) |
| if not report_files: |
| raise FileNotFoundError(f"No .txt reports found in: {reports_dir}") |
|
|
| samples: List[Dict[str, str]] = [] |
|
|
| for report_path in report_files: |
| study_id = report_path.stem |
| report_text = clean_report_text(report_path.read_text(encoding="utf-8", errors="ignore")) |
| if len(report_text) < min_report_chars: |
| continue |
|
|
| image_paths = get_study_image_paths(dataset_root, images_glob, study_id) |
| if not image_paths: |
| continue |
|
|
| if max_images_per_study > 0: |
| image_paths = image_paths[:max_images_per_study] |
|
|
| for image_path in image_paths: |
| samples.append( |
| { |
| "study_id": study_id, |
| "image_path": str(image_path), |
| "report_text": report_text, |
| } |
| ) |
|
|
| if not samples: |
| raise RuntimeError("No valid (image, report) samples were built.") |
|
|
| return samples |
|
|
|
|
| def split_by_study( |
| samples: List[Dict[str, str]], |
| val_ratio: float, |
| seed: int, |
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Split at the study level so a report never appears in both train and val.

    Several images can share one report, so a per-sample split would leak
    identical report text across the splits.
    """
    study_ids = sorted({s["study_id"] for s in samples})
    rng = random.Random(seed)
    rng.shuffle(study_ids)
|
|
| val_count = max(1, int(len(study_ids) * val_ratio)) if val_ratio > 0 else 0 |
| val_ids = set(study_ids[:val_count]) |
|
|
| train_samples = [s for s in samples if s["study_id"] not in val_ids] |
| val_samples = [s for s in samples if s["study_id"] in val_ids] |
| return train_samples, val_samples |
|
|
|
|
| def _build_messages( |
| image: Image.Image, |
| report_text: str, |
| instruction: str, |
| system_prompt: Optional[str], |
) -> Dict[str, List[Dict]]:
    """Assemble a chat sample: optional system turn, user turn (instruction + image), assistant report."""
    messages: List[Dict] = []
| if system_prompt: |
| messages.append( |
| { |
| "role": "system", |
| "content": [{"type": "text", "text": system_prompt}], |
| } |
| ) |
|
|
| messages.extend( |
| [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "text", "text": instruction}, |
| {"type": "image", "image": image}, |
| ], |
| }, |
| { |
| "role": "assistant", |
| "content": [{"type": "text", "text": report_text}], |
| }, |
| ] |
| ) |
|
|
| return {"messages": messages} |
|
|
|
|
def load_image_validity_cache(cache_path: Path) -> Dict[str, bool]:
    """Load the {image_path: is_readable} JSON cache; return an empty dict on any failure."""
    if not cache_path.exists():
        return {}
    try:
        data = json.loads(cache_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # json.JSONDecodeError is a ValueError subclass
        return {}
| if not isinstance(data, dict): |
| return {} |
| return {str(key): bool(value) for key, value in data.items()} |
|
|
|
|
def save_image_validity_cache(cache_path: Path, cache: Dict[str, bool]) -> None:
    """Persist the readability cache as JSON, creating parent directories as needed."""
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_text(json.dumps(cache), encoding="utf-8")
|
|
|
|
| def filter_readable_samples( |
| samples: List[Dict[str, str]], |
| cache_path: Path, |
| split_name: str, |
) -> List[Dict[str, str]]:
    """Drop samples whose image fails PIL verification, reading and updating the cache."""
    cache = load_image_validity_cache(cache_path)
|
|
| filtered: List[Dict[str, str]] = [] |
| skipped = 0 |
| newly_checked = 0 |
|
|
| for sample in samples: |
| image_path = sample["image_path"] |
| is_valid = cache.get(image_path) |
|
|
| if is_valid is None: |
| newly_checked += 1 |
            try:
                # verify() checks file integrity without fully decoding the pixels,
                # so it is cheap but can miss some forms of corruption.
                with Image.open(image_path) as opened_image:
                    opened_image.verify()
| is_valid = True |
| except (OSError, UnidentifiedImageError, ValueError): |
| is_valid = False |
| cache[image_path] = is_valid |
|
|
| if is_valid: |
| filtered.append(sample) |
| else: |
| skipped += 1 |
|
|
| save_image_validity_cache(cache_path, cache) |
| print( |
| f"[{split_name}] Kept {len(filtered)} / {len(samples)} samples, skipped {skipped} corrupt images " |
| f"(newly checked: {newly_checked})." |
| ) |
| return filtered |
|
|
|
|
def build_hf_dataset(
    samples: List[Dict[str, str]],
) -> Dataset:
    """Wrap (image_path, report_text) pairs in a datasets.Dataset of plain strings."""
| rows = [ |
| { |
| "image_path": sample["image_path"], |
| "report_text": sample["report_text"], |
| } |
| for sample in samples |
| ] |
| return Dataset.from_list(rows) |
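
# Rows hold only strings (image paths + report text), so the underlying Arrow
# table stays small; pixels are decoded lazily by the transform attached below.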
|
|
|
|
| def attach_lazy_vision_transform( |
| dataset: Dataset, |
| instruction: str, |
| system_prompt: Optional[str], |
| split_name: str, |
) -> Dataset:
    """Attach a set_transform that opens images and builds chat messages lazily on access."""
    skipped = {"count": 0}
|
|
| def transform(examples: Dict[str, List[str] | str]) -> Dict[str, List[Dict]]: |
| image_paths = examples["image_path"] |
| report_texts = examples["report_text"] |
| is_batch = isinstance(image_paths, list) |
|
|
| if not is_batch: |
| image_paths = [image_paths] |
| report_texts = [report_texts] |
|
|
| messages_batch: List[List[Dict]] = [] |
| for image_path, report_text in zip(image_paths, report_texts): |
            try:
                with Image.open(str(image_path)) as opened_image:
                    image = opened_image.convert("RGB")
            except (OSError, UnidentifiedImageError, ValueError) as error:
                skipped["count"] += 1
                if skipped["count"] <= 5:
                    print(f"[{split_name}] Runtime unreadable image: {image_path} ({error})")
                # Fall back to a black placeholder so the batch keeps its shape; the paired
                # report then trains against a blank image, which is why up-front
                # verification (filter_readable_samples) is preferred.
                image = Image.new("RGB", (224, 224), color=(0, 0, 0))
|
|
| messages_batch.append( |
| _build_messages( |
| image=image, |
| report_text=str(report_text), |
| instruction=instruction, |
| system_prompt=system_prompt, |
| )["messages"] |
| ) |
|
|
| if is_batch: |
| return {"messages": messages_batch} |
| return {"messages": messages_batch[0]} |
|
|
| dataset.set_transform(transform) |
| return dataset |
|
|
|
|
def print_gpu_memory_stats(prefix: str) -> None:
    """Print the GPU name, total memory, and peak reserved memory for device 0."""
    if not torch.cuda.is_available():
        print(f"[{prefix}] CUDA not available.")
        return

    gpu_stats = torch.cuda.get_device_properties(0)
    total_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    peak_reserved = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    print(f"[{prefix}] GPU = {gpu_stats.name}")
    print(f"[{prefix}] Total GPU memory = {total_memory} GB")
    print(f"[{prefix}] Peak reserved memory = {peak_reserved} GB")
|
|
|
|
def main() -> None:
    """Build data, configure the model and trainer, run training, and save the adapter."""
    args = parse_args()
|
|
    # Resolve the paired flags: --no_4bit wins over the default-on --load_in_4bit.
    if args.no_4bit:
        args.load_in_4bit = False
|
|
| if args.val_ratio < 0 or args.val_ratio >= 1: |
| raise ValueError("--val_ratio must be in [0, 1).") |
|
|
| if args.num_train_epochs <= 0: |
| raise ValueError("--num_train_epochs must be > 0.") |
|
|
| if args.use_wandb: |
| import os |
|
|
| os.environ["WANDB_PROJECT"] = args.wandb_project |
| if args.wandb_run_name: |
| os.environ["WANDB_NAME"] = args.wandb_run_name |
| if args.wandb_entity: |
| os.environ["WANDB_ENTITY"] = args.wandb_entity |
|
|
| print(f"Using epoch-based training for {args.num_train_epochs} epochs.") |
|
|
| random.seed(args.seed) |
| torch.manual_seed(args.seed) |
|
|
| dataset_root = Path(args.dataset_root) |
| if not dataset_root.exists(): |
| raise FileNotFoundError(f"Dataset root not found: {dataset_root}") |
|
|
| print("Loading model...") |
| model, tokenizer = FastVisionModel.from_pretrained( |
| args.model_name, |
| load_in_4bit=args.load_in_4bit, |
| use_gradient_checkpointing="unsloth", |
| ) |
|
|
| model = FastVisionModel.get_peft_model( |
| model, |
| finetune_vision_layers=True, |
| finetune_language_layers=True, |
| finetune_attention_modules=True, |
| finetune_mlp_modules=True, |
| r=args.lora_r, |
| lora_alpha=args.lora_alpha, |
| lora_dropout=args.lora_dropout, |
| bias="none", |
| random_state=args.seed, |
| use_rslora=False, |
| loftq_config=None, |
| ) |
|
|
| print("Building paired image-report samples...") |
| samples = build_samples( |
| dataset_root=dataset_root, |
| reports_dir_name=args.reports_dir, |
| images_glob=args.images_glob, |
| min_report_chars=args.min_report_chars, |
| max_images_per_study=args.max_images_per_study, |
| ) |
|
|
| train_samples, val_samples = split_by_study(samples, args.val_ratio, args.seed) |
|
|
| if args.max_train_samples > 0: |
| train_samples = train_samples[: args.max_train_samples] |
| if args.max_val_samples > 0: |
| val_samples = val_samples[: args.max_val_samples] |
|
|
| if args.skip_image_verification: |
| print("Skipping image verification step as requested.") |
| else: |
| cache_path = ( |
| Path(args.image_validity_cache) |
| if args.image_validity_cache |
| else dataset_root / ".image_validity_cache.json" |
| ) |
| print(f"Verifying image readability with cache: {cache_path}") |
| train_samples = filter_readable_samples(train_samples, cache_path, split_name="train") |
| if val_samples: |
| val_samples = filter_readable_samples(val_samples, cache_path, split_name="val") |
|
|
| print(f"Total samples: {len(samples)}") |
| print(f"Train samples: {len(train_samples)}") |
| print(f"Val samples: {len(val_samples)}") |
|
|
    # An empty or whitespace-only --system_prompt disables the system turn entirely.
    system_prompt = args.system_prompt.strip() if args.system_prompt else ""
    if not system_prompt:
        system_prompt = None
|
|
| train_dataset = build_hf_dataset(train_samples) |
| train_dataset = attach_lazy_vision_transform(train_dataset, args.instruction, system_prompt, split_name="train") |
| eval_dataset = ( |
| attach_lazy_vision_transform(build_hf_dataset(val_samples), args.instruction, system_prompt, split_name="val") |
| if val_samples |
| else None |
| ) |
|
|
| print(f"Final train dataset size: {len(train_dataset)}") |
| if eval_dataset is not None: |
| print(f"Final val dataset size: {len(eval_dataset)}") |
|
|
| FastVisionModel.for_training(model) |
|
|
| config_kwargs = { |
| "per_device_train_batch_size": args.batch_size, |
| "gradient_accumulation_steps": args.grad_accum, |
| "warmup_steps": args.warmup_steps, |
| "learning_rate": args.learning_rate, |
| "logging_steps": args.logging_steps, |
| "optim": "adamw_8bit", |
| "weight_decay": 0.001, |
| "lr_scheduler_type": "linear", |
| "seed": args.seed, |
| "output_dir": args.output_dir, |
| "report_to": "wandb" if args.use_wandb else "none", |
| "save_steps": args.save_steps, |
| "remove_unused_columns": False, |
| "dataset_text_field": "", |
| "dataset_kwargs": {"skip_prepare_dataset": True}, |
| "max_length": args.max_length, |
| "num_train_epochs": args.num_train_epochs, |
| "dataloader_num_workers": args.dataloader_num_workers, |
| } |
|
|
| trainer = SFTTrainer( |
| model=model, |
| tokenizer=tokenizer, |
| data_collator=UnslothVisionDataCollator(model, tokenizer), |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| args=SFTConfig(**config_kwargs), |
| ) |
|
|
| print_gpu_memory_stats("BEFORE TRAIN") |
| trainer_stats = trainer.train() |
| print_gpu_memory_stats("AFTER TRAIN") |
|
|
| print("Train metrics:") |
| print(trainer_stats.metrics) |
|
|
| output_dir = Path(args.output_dir) |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| model.save_pretrained(str(output_dir)) |
| tokenizer.save_pretrained(str(output_dir)) |
|
|
| print(f"Saved LoRA adapter + tokenizer to: {output_dir}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
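

# To sanity-check the saved adapter, a rough sketch following the Unsloth vision
# examples (exact keyword arguments may differ across Unsloth versions):
#
#   from unsloth import FastVisionModel
#
#   model, tokenizer = FastVisionModel.from_pretrained(
#       "outputs/mimic_qwen3vl_lora", load_in_4bit=True
#   )
#   FastVisionModel.for_inference(model)
#   # Build messages as in _build_messages (without the assistant turn), then
#   # apply tokenizer.apply_chat_template(...) and call model.generate(...).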