"""Fine-tune Qwen3-VL with Unsloth LoRA on paired chest X-ray images and report texts."""

# Unsloth must be imported before trl/transformers so its patches apply.
from unsloth import FastVisionModel
from unsloth.trainer import UnslothVisionDataCollator

import argparse
import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from PIL import Image, ImageFile, UnidentifiedImageError
from trl import SFTConfig, SFTTrainer

# Let Pillow load truncated image files instead of raising on the missing tail.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Exceptions Pillow may raise when a file is unreadable or corrupt.
# UnidentifiedImageError is an OSError subclass (listed for readability);
# SyntaxError is raised by some format plugins (e.g. PNG chunk checks) and
# DecompressionBombError by the size guard -- neither is an OSError.
_IMAGE_READ_ERRORS = (
    OSError,
    UnidentifiedImageError,
    SyntaxError,
    ValueError,
    Image.DecompressionBombError,
)

# File suffixes (lower-cased) treated as study images.
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp"})

# NOTE(review): "resaoning" is a typo. It is left unchanged here because this text
# is the prompt the model is trained with; if it is fixed, fix it together with any
# inference code that reuses the same prompt.
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert radiology report generator. "
    "Given one chest X-ray image, write a clinically coherent report in the style of a radiologist. "
    "Do not hallucinate findings that are not supported by the image. Moreover give resaoning for your findings and highlight the key areas or features in the image that support your findings. "
)
DEFAULT_INSTRUCTION = (
    "Analyze this chest X-ray image and generate the corresponding radiology report text."
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a fine-tuning run."""
    parser = argparse.ArgumentParser(
        description="Finetune Qwen3-VL 8B with Unsloth on MIMIC-style image/report data."
    )
    # Data layout and sampling.
    parser.add_argument("--dataset_root", type=str, default="dataset")
    parser.add_argument("--reports_dir", type=str, default="files")
    parser.add_argument("--images_glob", type=str, default="images_part_*")
    parser.add_argument("--model_name", type=str, default="unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit")
    parser.add_argument("--output_dir", type=str, default="outputs/mimic_qwen3vl_lora")
    parser.add_argument("--seed", type=int, default=3407)
    parser.add_argument("--val_ratio", type=float, default=0.05)
    parser.add_argument("--max_images_per_study", type=int, default=0, help="0 = use all images per study")
    parser.add_argument("--max_train_samples", type=int, default=0, help="0 = use all train samples")
    parser.add_argument("--max_val_samples", type=int, default=0, help="0 = use all val samples")
    parser.add_argument("--min_report_chars", type=int, default=40)
    parser.add_argument(
        "--image_validity_cache",
        type=str,
        default="",
        help="Path to JSON cache for image readability checks. Default: <dataset_root>/.image_validity_cache.json",
    )
    parser.add_argument(
        "--skip_image_verification",
        action="store_true",
        help="Skip pre-verifying image files. Faster startup, but corrupted images may fail at runtime.",
    )
    # Prompting.
    parser.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION)
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=DEFAULT_SYSTEM_PROMPT,
        help="Set to empty string to disable system prompt.",
    )
    # Model / LoRA. --load_in_4bit is effectively always True (store_true with
    # default=True) and is kept for CLI compatibility; use --no_4bit to disable.
    parser.add_argument("--load_in_4bit", action="store_true", default=True)
    parser.add_argument("--no_4bit", action="store_true", help="Disable 4bit loading")
    parser.add_argument("--lora_r", type=int, default=32)
    parser.add_argument("--lora_alpha", type=int, default=64)
    parser.add_argument("--lora_dropout", type=float, default=0.01)
    # Optimisation / trainer.
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--num_train_epochs", type=float, default=5.0)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=200)
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=0,
        help="Evaluate on the val split every N steps. 0 = reuse --save_steps.",
    )
    parser.add_argument("--no_eval", action="store_true", help="Disable periodic evaluation on the val split")
    parser.add_argument("--dataloader_num_workers", type=int, default=0)
    # Weights & Biases.
    parser.add_argument("--use_wandb", action="store_true", help="Enable Weights & Biases logging")
    parser.add_argument("--wandb_project", type=str, default="qwen3vl-mimic-finetune")
    parser.add_argument("--wandb_run_name", type=str, default="")
    parser.add_argument("--wandb_entity", type=str, default="")
    return parser.parse_args()


def clean_report_text(text: str) -> str:
    """Strip every line of ``text`` and drop blank lines, joining the rest with newlines."""
    lines = [line.strip() for line in text.splitlines()]
    non_empty = [line for line in lines if line]
    return "\n".join(non_empty).strip()


def _list_image_part_dirs(dataset_root: Path, images_glob: str) -> List[Path]:
    """Return the sorted directories directly matching ``images_glob`` under ``dataset_root``."""
    return [part for part in sorted(dataset_root.glob(images_glob)) if part.is_dir()]


def get_study_image_paths(
    dataset_root: Path,
    images_glob: str,
    study_id: str,
    image_parts: Optional[List[Path]] = None,
) -> List[Path]:
    """Collect image files for ``study_id`` across all image-part directories.

    Args:
        dataset_root: Root folder containing the image-part directories.
        images_glob: Glob (relative to ``dataset_root``) selecting image-part directories.
        study_id: Name of the per-study sub-directory inside each image part.
        image_parts: Optional pre-computed list of image-part directories. When
            given, ``dataset_root``/``images_glob`` are not re-globbed, which avoids
            listing the dataset root once per study.

    Returns:
        Image paths (suffix in ``_IMAGE_EXTENSIONS``, case-insensitive), ordered by
        image part and then by file name.
    """
    if image_parts is None:
        image_parts = _list_image_part_dirs(dataset_root, images_glob)
    image_paths: List[Path] = []
    for images_part in image_parts:
        study_dir = images_part / study_id
        if not study_dir.is_dir():  # also False when the path does not exist
            continue
        image_paths.extend(
            path for path in sorted(study_dir.iterdir()) if path.suffix.lower() in _IMAGE_EXTENSIONS
        )
    return image_paths


def build_samples(
    dataset_root: Path,
    reports_dir_name: str,
    images_glob: str,
    min_report_chars: int,
    max_images_per_study: int,
) -> List[Dict[str, str]]:
    """Pair every report ``<study_id>.txt`` with each image of that study.

    Reports shorter than ``min_report_chars`` (after cleaning) and studies with no
    images are skipped. One sample dict (``study_id``, ``image_path``,
    ``report_text``) is produced per image.

    Raises:
        FileNotFoundError: if the reports folder, any report, or any image-part
            directory is missing.
        RuntimeError: if no (image, report) pair could be built.
    """
    reports_dir = dataset_root / reports_dir_name
    if not reports_dir.exists():
        raise FileNotFoundError(f"Reports folder not found: {reports_dir}")
    report_files = sorted(reports_dir.glob("*.txt"))
    if not report_files:
        raise FileNotFoundError(f"No .txt reports found in: {reports_dir}")

    # Glob the image-part folders once instead of once per report.
    image_parts = _list_image_part_dirs(dataset_root, images_glob)
    if not image_parts:
        raise FileNotFoundError(f"No image folders matching '{images_glob}' found in: {dataset_root}")

    samples: List[Dict[str, str]] = []
    for report_path in report_files:
        study_id = report_path.stem
        report_text = clean_report_text(report_path.read_text(encoding="utf-8", errors="ignore"))
        if len(report_text) < min_report_chars:
            continue
        image_paths = get_study_image_paths(dataset_root, images_glob, study_id, image_parts=image_parts)
        if not image_paths:
            continue
        if max_images_per_study > 0:
            image_paths = image_paths[:max_images_per_study]
        for image_path in image_paths:
            samples.append(
                {
                    "study_id": study_id,
                    "image_path": str(image_path),
                    "report_text": report_text,
                }
            )
    if not samples:
        raise RuntimeError("No valid (image, report) samples were built.")
    return samples


def split_by_study(
    samples: List[Dict[str, str]],
    val_ratio: float,
    seed: int,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Split samples into train/val so that all images of a study land in the same split.

    Study ids are shuffled with a dedicated ``random.Random(seed)`` (the global RNG
    is untouched). When ``val_ratio > 0`` at least one study goes to validation, but
    never all of them: at least one study is always kept for training.
    """
    study_ids = sorted({s["study_id"] for s in samples})
    rng = random.Random(seed)
    rng.shuffle(study_ids)
    if val_ratio > 0:
        # Cap at n-1 so a single-study (or tiny) dataset still has training data.
        val_count = min(max(1, int(len(study_ids) * val_ratio)), max(len(study_ids) - 1, 0))
    else:
        val_count = 0
    val_ids = set(study_ids[:val_count])
    train_samples = [s for s in samples if s["study_id"] not in val_ids]
    val_samples = [s for s in samples if s["study_id"] in val_ids]
    return train_samples, val_samples


def _build_messages(
    image: Image.Image,
    report_text: str,
    instruction: str,
    system_prompt: Optional[str],
) -> Dict[str, List[Dict]]:
    """Build one chat conversation: optional system turn, user (text + image), assistant report."""
    messages: List[Dict] = []
    if system_prompt:
        messages.append(
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            }
        )
    messages.extend(
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image", "image": image},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": report_text}],
            },
        ]
    )
    return {"messages": messages}


def load_image_validity_cache(cache_path: Path) -> Dict[str, bool]:
    """Load the ``{image_path: is_readable}`` cache; return ``{}`` if missing or unparsable."""
    if not cache_path.exists():
        return {}
    try:
        data = json.loads(cache_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # json.JSONDecodeError is a ValueError subclass
        return {}
    if not isinstance(data, dict):
        return {}
    return {str(key): bool(value) for key, value in data.items()}


def save_image_validity_cache(cache_path: Path, cache: Dict[str, bool]) -> None:
    """Write the readability cache as JSON.

    The data is written to a sibling ``.tmp`` file and then moved over the target,
    so an interrupted write cannot leave a truncated cache behind.
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    tmp_path.write_text(json.dumps(cache), encoding="utf-8")
    tmp_path.replace(cache_path)


def filter_readable_samples(
    samples: List[Dict[str, str]],
    cache_path: Path,
    split_name: str,
) -> List[Dict[str, str]]:
    """Drop samples whose image fails ``PIL.Image.verify()``, caching results on disk.

    Paths already present in the cache are not re-checked. The updated cache is
    saved once after all samples are processed.
    """
    cache = load_image_validity_cache(cache_path)
    filtered: List[Dict[str, str]] = []
    skipped = 0
    newly_checked = 0
    for sample in samples:
        image_path = sample["image_path"]
        is_valid = cache.get(image_path)
        if is_valid is None:
            newly_checked += 1
            try:
                with Image.open(image_path) as opened_image:
                    opened_image.verify()
                is_valid = True
            except _IMAGE_READ_ERRORS:
                is_valid = False
            cache[image_path] = is_valid
        if is_valid:
            filtered.append(sample)
        else:
            skipped += 1
    save_image_validity_cache(cache_path, cache)
    print(
        f"[{split_name}] Kept {len(filtered)} / {len(samples)} samples, skipped {skipped} corrupt images "
        f"(newly checked: {newly_checked})."
    )
    return filtered


def build_hf_dataset(
    samples: List[Dict[str, str]],
) -> Dataset:
    """Build a ``datasets.Dataset`` holding only ``image_path`` and ``report_text`` columns."""
    rows = [
        {
            "image_path": sample["image_path"],
            "report_text": sample["report_text"],
        }
        for sample in samples
    ]
    return Dataset.from_list(rows)


def attach_lazy_vision_transform(
    dataset: Dataset,
    instruction: str,
    system_prompt: Optional[str],
    split_name: str,
) -> Dataset:
    """Attach an on-access transform that loads each image and builds chat ``messages``.

    Images are opened only when a row is accessed, so the full dataset is never held
    in memory. An image that cannot be read at this point is replaced by a 224x224
    black placeholder so dataset indexing stays intact. Note that the sample still
    carries the original report as its target; keep image verification enabled to
    make this rare. Only the first 5 such failures per process are printed.
    """
    skipped = {"count": 0}

    def transform(examples: Dict[str, Union[List[str], str]]) -> Dict[str, List[Dict]]:
        image_paths = examples["image_path"]
        report_texts = examples["report_text"]
        # The transform receives either a single row or a batch of rows.
        is_batch = isinstance(image_paths, list)
        if not is_batch:
            image_paths = [image_paths]
            report_texts = [report_texts]
        messages_batch: List[List[Dict]] = []
        for image_path, report_text in zip(image_paths, report_texts):
            try:
                with Image.open(str(image_path)) as opened_image:
                    image = opened_image.convert("RGB")
            except _IMAGE_READ_ERRORS as error:
                skipped["count"] += 1
                if skipped["count"] <= 5:
                    print(f"[{split_name}] Runtime unreadable image: {image_path} ({error})")
                image = Image.new("RGB", (224, 224), color=(0, 0, 0))
            messages_batch.append(
                _build_messages(
                    image=image,
                    report_text=str(report_text),
                    instruction=instruction,
                    system_prompt=system_prompt,
                )["messages"]
            )
        if is_batch:
            return {"messages": messages_batch}
        return {"messages": messages_batch[0]}

    dataset.set_transform(transform)
    return dataset


def print_gpu_memory_stats(prefix: str) -> None:
    """Print GPU 0's name, total memory and peak reserved memory (GB), if CUDA is available."""
    if not torch.cuda.is_available():
        print(f"[{prefix}] CUDA not available.")
        return
    gpu_stats = torch.cuda.get_device_properties(0)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    print(f"[{prefix}] GPU = {gpu_stats.name}")
    print(f"[{prefix}] Max GPU memory = {max_memory} GB")
    print(f"[{prefix}] Reserved memory = {used_memory} GB")


def _configure_wandb(args: argparse.Namespace) -> None:
    """Export Weights & Biases settings as environment variables when ``--use_wandb`` is set."""
    if not args.use_wandb:
        return
    os.environ["WANDB_PROJECT"] = args.wandb_project
    if args.wandb_run_name:
        os.environ["WANDB_NAME"] = args.wandb_run_name
    if args.wandb_entity:
        os.environ["WANDB_ENTITY"] = args.wandb_entity


def _prepare_samples(
    args: argparse.Namespace,
    dataset_root: Path,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Build, split, truncate and (optionally) verify the train/val samples.

    Raises:
        RuntimeError: if no training samples remain after splitting and filtering.
    """
    print("Building paired image-report samples...")
    samples = build_samples(
        dataset_root=dataset_root,
        reports_dir_name=args.reports_dir,
        images_glob=args.images_glob,
        min_report_chars=args.min_report_chars,
        max_images_per_study=args.max_images_per_study,
    )
    train_samples, val_samples = split_by_study(samples, args.val_ratio, args.seed)
    # Truncation happens before verification, so the final counts can be smaller.
    if args.max_train_samples > 0:
        train_samples = train_samples[: args.max_train_samples]
    if args.max_val_samples > 0:
        val_samples = val_samples[: args.max_val_samples]

    if args.skip_image_verification:
        print("Skipping image verification step as requested.")
    else:
        cache_path = (
            Path(args.image_validity_cache)
            if args.image_validity_cache
            else dataset_root / ".image_validity_cache.json"
        )
        print(f"Verifying image readability with cache: {cache_path}")
        train_samples = filter_readable_samples(train_samples, cache_path, split_name="train")
        if val_samples:
            val_samples = filter_readable_samples(val_samples, cache_path, split_name="val")

    print(f"Total samples: {len(samples)}")
    print(f"Train samples: {len(train_samples)}")
    print(f"Val samples: {len(val_samples)}")
    if not train_samples:
        raise RuntimeError(
            "No training samples left after splitting/filtering; "
            "check --val_ratio, --max_train_samples and the image files."
        )
    return train_samples, val_samples


def _build_sft_config(args: argparse.Namespace, has_eval: bool) -> SFTConfig:
    """Translate CLI options into an ``SFTConfig``.

    When ``has_eval`` is True, step-based evaluation is enabled every
    ``--eval_steps`` steps (or ``--save_steps`` when ``--eval_steps`` is 0).
    Without an evaluation strategy the Trainer never runs evaluation, even if an
    ``eval_dataset`` is supplied.
    """
    config_kwargs = {
        "per_device_train_batch_size": args.batch_size,
        "gradient_accumulation_steps": args.grad_accum,
        "warmup_steps": args.warmup_steps,
        "learning_rate": args.learning_rate,
        "logging_steps": args.logging_steps,
        "optim": "adamw_8bit",
        "weight_decay": 0.001,
        "lr_scheduler_type": "linear",
        "seed": args.seed,
        "output_dir": args.output_dir,
        "report_to": "wandb" if args.use_wandb else "none",
        "save_steps": args.save_steps,
        # The lazy transform needs the raw columns; the collator does all preprocessing.
        "remove_unused_columns": False,
        "dataset_text_field": "",
        "dataset_kwargs": {"skip_prepare_dataset": True},
        "max_length": args.max_length,
        "num_train_epochs": args.num_train_epochs,
        "dataloader_num_workers": args.dataloader_num_workers,
    }
    if has_eval:
        config_kwargs.update(
            {
                "eval_strategy": "steps",
                "eval_steps": args.eval_steps if args.eval_steps > 0 else args.save_steps,
                "per_device_eval_batch_size": args.batch_size,
            }
        )
    return SFTConfig(**config_kwargs)


def main() -> None:
    """Run the full pipeline: parse args, prepare data, load the model, train and save the adapter."""
    args = parse_args()
    if args.no_4bit:
        args.load_in_4bit = False
    if args.val_ratio < 0 or args.val_ratio >= 1:
        raise ValueError("--val_ratio must be in [0, 1).")
    if args.num_train_epochs <= 0:
        raise ValueError("--num_train_epochs must be > 0.")

    _configure_wandb(args)

    print(f"Using epoch-based training for {args.num_train_epochs} epochs.")

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_root = Path(args.dataset_root)
    if not dataset_root.exists():
        raise FileNotFoundError(f"Dataset root not found: {dataset_root}")

    # Prepare data before loading the model so dataset problems surface before the
    # slow, GPU-heavy model load. Data preparation consumes neither the global
    # `random` nor the torch RNG, so seeding behaviour is unchanged.
    train_samples, val_samples = _prepare_samples(args, dataset_root)

    system_prompt = args.system_prompt.strip() if args.system_prompt else ""
    if not system_prompt:
        system_prompt = None

    train_dataset = build_hf_dataset(train_samples)
    train_dataset = attach_lazy_vision_transform(train_dataset, args.instruction, system_prompt, split_name="train")
    eval_dataset = (
        attach_lazy_vision_transform(build_hf_dataset(val_samples), args.instruction, system_prompt, split_name="val")
        if val_samples
        else None
    )
    print(f"Final train dataset size: {len(train_dataset)}")
    if eval_dataset is not None:
        print(f"Final val dataset size: {len(eval_dataset)}")
    has_eval = eval_dataset is not None and not args.no_eval

    print("Loading model...")
    model, tokenizer = FastVisionModel.from_pretrained(
        args.model_name,
        load_in_4bit=args.load_in_4bit,
        use_gradient_checkpointing="unsloth",
    )
    model = FastVisionModel.get_peft_model(
        model,
        finetune_vision_layers=True,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        random_state=args.seed,
        use_rslora=False,
        loftq_config=None,
    )

    FastVisionModel.for_training(model)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=UnslothVisionDataCollator(model, tokenizer),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset if has_eval else None,
        args=_build_sft_config(args, has_eval=has_eval),
    )

    print_gpu_memory_stats("BEFORE TRAIN")
    trainer_stats = trainer.train()
    print_gpu_memory_stats("AFTER TRAIN")
    print("Train metrics:")
    print(trainer_stats.metrics)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"Saved LoRA adapter + tokenizer to: {output_dir}")


if __name__ == "__main__":
    main()