| import os |
# Provide defaults for the CUDA environment; `setdefault` leaves any value that
# is already exported by the caller untouched.
# NOTE(review): these sit before the unsloth/torch imports below, presumably so
# they are in place before CUDA is initialised -- confirm before moving them.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
|
|
| from unsloth import FastVisionModel |
| from unsloth.trainer import UnslothVisionDataCollator |
|
|
| import argparse |
| import csv |
| import json |
| import random |
| import time |
| import traceback |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| from datasets import Dataset |
| from PIL import Image, ImageFile, UnidentifiedImageError |
| from trl import SFTConfig, SFTTrainer |
| from transformers import TrainerCallback |
|
|
|
|
# Pillow flag: allow decoding of image files whose data ends prematurely
# instead of raising while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
# Default system prompt placed at the start of every training and evaluation
# conversation (overridable with --system_prompt).
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert radiology report generator. "
    "Given one chest X-ray image, write a clinically coherent report in the style of a radiologist. "
    "Do not hallucinate findings that are not supported by the image. Moreover give reasoning for your findings and highlight the key areas or features in the image that support your findings. "
    "Include concise clinical reasoning for each key finding and explain why the visual evidence supports your conclusion. "
)
|
|
# Default user-turn instruction that accompanies the image(s) in every sample
# (overridable with --instruction).
DEFAULT_INSTRUCTION = (
    "Analyze this chest X-ray image and generate the corresponding radiology report text "
    "with concise reasoning for why each key finding is present or absent."
)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. Flags are grouped into data selection, prompt
        text, quantisation/device, LoRA hyper-parameters, trainer settings,
        evaluation sampling and Weights & Biases logging.
    """
    parser = argparse.ArgumentParser(
        description="Finetune Qwen3-VL 8B with Unsloth on MIMIC-style image/report data."
    )

    # --- Data locations and model identity ---------------------------------
    parser.add_argument("--dataset_root", type=str, default="dataset")
    parser.add_argument("--reports_dir", type=str, default="files")
    parser.add_argument("--images_glob", type=str, default="images_*")
    parser.add_argument("--model_name", type=str, default="unsloth/Qwen3-VL-8B-Thinking")
    parser.add_argument("--output_dir", type=str, default="outputs/mimic_qwen3vl_lora_8bit_5")

    # --- Sampling / filtering of the dataset -------------------------------
    parser.add_argument("--seed", type=int, default=3407)
    parser.add_argument("--val_ratio", type=float, default=0.01)
    parser.add_argument("--max_images_per_study", type=int, default=0, help="0 = use all images per study")
    parser.add_argument("--max_train_samples", type=int, default=0, help="0 = use all train samples")
    parser.add_argument("--max_val_samples", type=int, default=0, help="0 = use all val samples")
    parser.add_argument("--min_report_chars", type=int, default=40)
    parser.add_argument(
        "--image_validity_cache",
        type=str,
        default="",
        help="Path to JSON cache for image readability checks. Default: <dataset_root>/.image_validity_cache.json",
    )
    parser.add_argument(
        "--skip_image_verification",
        action="store_true",
        help="Skip pre-verifying image files. Faster startup, but corrupted images may fail at runtime.",
    )

    # --- Prompt text ---------------------------------------------------------
    parser.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION)
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=DEFAULT_SYSTEM_PROMPT,
        help="Set to empty string to disable system prompt.",
    )

    # --- Quantisation and device -------------------------------------------
    # --load_in_8bit is kept for CLI compatibility; it is already True by default.
    parser.add_argument("--load_in_8bit", action="store_true", default=True)
    parser.add_argument("--no_8bit", action="store_true", help="Disable 8bit loading")
    parser.add_argument(
        "--cuda_device",
        type=int,
        default=0,
        help="Visible CUDA device index to use for strict 8-bit training (default 0 when CUDA_VISIBLE_DEVICES is set).",
    )

    # --- LoRA hyper-parameters ----------------------------------------------
    parser.add_argument("--lora_r", type=int, default=32)
    parser.add_argument("--lora_alpha", type=int, default=64)
    parser.add_argument("--lora_dropout", type=float, default=0.01)

    # --- Trainer settings -----------------------------------------------------
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument("--grad_accum", type=int, default=2)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--warmup_steps", type=int, default=750)
    parser.add_argument("--num_train_epochs", type=float, default=3.0)
    parser.add_argument("--max_length", type=int, default=4096)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--max_train_retries", type=int, default=5)
    parser.add_argument("--retry_wait_seconds", type=int, default=15)
    parser.add_argument("--dataloader_num_workers", type=int, default=0)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default="",
        help="Optional checkpoint path to resume from explicitly.",
    )

    # --- Evaluation sample generation ---------------------------------------
    parser.add_argument(
        "--eval_sample_count",
        type=int,
        default=1,
        # The help text previously advertised "default 3" although the default is 1.
        help="Random validation generations to print each eval (default 1). Set 0 to disable.",
    )
    parser.add_argument(
        "--eval_max_new_tokens",
        type=int,
        default=96,
        help="Max new tokens for eval sample generation callback.",
    )

    # --- Weights & Biases -----------------------------------------------------
    parser.add_argument("--use_wandb", action="store_true", help="Enable Weights & Biases logging")
    parser.add_argument("--wandb_project", type=str, default="qwen3vl-mimic-finetune-8bit")
    parser.add_argument("--wandb_run_name", type=str, default="")
    parser.add_argument("--wandb_entity", type=str, default="")

    return parser.parse_args()
|
|
|
|
def extract_findings_impression(text: str) -> Optional[str]:
    """Extract only the FINDINGS and IMPRESSION sections from a report.

    Handles these report layouts:
      - Inline content:  "FINDINGS: The heart size is..."
      - Newline content: "FINDINGS:\\n\\n The heart size is..."
      - Either ordering of the two sections (output is always Findings first).
      - "PROVISIONAL FINDINGS IMPRESSION (PFI):" is never used as a source
        section, but it still terminates the section that precedes it so its
        text cannot leak into FINDINGS/IMPRESSION.

    Strategy: locate every section header first, then slice the text between
    consecutive headers.

    Args:
        text: Raw report text.

    Returns:
        "Findings:\\n...\\n\\nImpression:\\n..." containing whichever of the
        two sections were found with non-empty content, or ``None`` when
        neither is present.
    """
    import re

    text = text.replace("\r\n", "\n").replace("\r", "\n")

    # A header is a line that starts with a label of letters, spaces, parentheses
    # or slashes (at least four characters) followed by a colon.
    header_re = re.compile(r"^[ \t]*([A-Za-z][A-Za-z ()/]{3,}):[ \t]*", re.MULTILINE)
    # Horizontal-rule lines ("_____", "-----", "=====") carry no content.
    rule_re = re.compile(r"^[_\-=]{5,}$")
    wanted = {"FINDINGS": "Findings", "IMPRESSION": "Impression"}

    # Every header is a boundary -- including PROVISIONAL ones, which simply never
    # match a wanted name below.
    headers = [
        (match.group(1).strip().upper(), match.start(), match.end())
        for match in header_re.finditer(text)
    ]

    found: dict = {}
    for index, (name, _header_start, content_start) in enumerate(headers):
        key = wanted.get(name)
        # Skip unrelated sections, and keep only the first non-empty occurrence.
        if key is None or key in found:
            continue

        content_end = headers[index + 1][1] if index + 1 < len(headers) else len(text)
        stripped_lines = (line.strip() for line in text[content_start:content_end].splitlines())
        content = "\n".join(
            line for line in stripped_lines if line and not rule_re.match(line)
        )
        if content:
            found[key] = content

    if not found:
        return None

    return "\n\n".join(
        f"{heading}:\n{found[heading]}"
        for heading in ("Findings", "Impression")
        if heading in found
    )
|
|
|
|
def clean_report_text(text: str) -> str:
    """Normalise a report to the text used as the training target.

    Prefers the extracted Findings/Impression sections; when neither section
    is found, falls back to the whole report with blank lines removed and
    every line stripped.
    """
    sections = extract_findings_impression(text)
    if sections:
        return sections.strip()

    # Fallback: keep every non-blank line, trimmed.
    trimmed = (raw_line.strip() for raw_line in text.splitlines())
    return "\n".join(filter(None, trimmed)).strip()
|
|
|
|
def get_study_image_paths(dataset_root: Path, images_glob: str, study_id: str) -> List[Path]:
    """Collect the image files of one study across all image partitions.

    Args:
        dataset_root: Directory containing the image partition folders.
        images_glob: Glob (relative to ``dataset_root``) matching the
            partition folders.
        study_id: Name of the per-study sub-folder inside each partition.

    Returns:
        Image paths ordered by partition, then by file path within the study
        folder. Only files with a known image suffix are included.
    """
    allowed_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
    collected: List[Path] = []

    for partition in sorted(dataset_root.glob(images_glob)):
        study_folder = partition / study_id
        # Ignore glob hits that are not folders and partitions lacking this study.
        if not (partition.is_dir() and study_folder.is_dir()):
            continue
        collected.extend(
            candidate
            for candidate in sorted(study_folder.iterdir())
            if candidate.suffix.lower() in allowed_suffixes
        )

    return collected
|
|
|
|
def build_samples(
    dataset_root: Path,
    reports_dir_name: str,
    images_glob: str,
    min_report_chars: int,
    max_images_per_study: int,
) -> List[Dict[str, object]]:
    """Pair every usable report with the images of the same study.

    A report file ``<study_id>.txt`` is kept when its cleaned text has at
    least ``min_report_chars`` characters and at least one image exists for
    the study.

    Args:
        dataset_root: Root folder holding the reports folder and the image
            partitions.
        reports_dir_name: Name of the reports folder under ``dataset_root``.
        images_glob: Glob for the image partition folders.
        min_report_chars: Minimum length of the cleaned report text.
        max_images_per_study: Cap on images per study; ``0`` keeps all.

    Returns:
        One dict per study with keys ``study_id``, ``image_paths`` (list of
        str) and ``report_text``.

    Raises:
        FileNotFoundError: The reports folder is missing or has no ``.txt``.
        RuntimeError: No study produced a valid sample.
    """
    reports_dir = dataset_root / reports_dir_name
    if not reports_dir.exists():
        raise FileNotFoundError(f"Reports folder not found: {reports_dir}")

    report_files = sorted(reports_dir.glob("*.txt"))
    if not report_files:
        raise FileNotFoundError(f"No .txt reports found in: {reports_dir}")

    paired: List[Dict[str, object]] = []
    for report_file in report_files:
        raw_text = report_file.read_text(encoding="utf-8", errors="ignore")
        cleaned = clean_report_text(raw_text)
        if len(cleaned) < min_report_chars:
            continue

        study = report_file.stem
        study_images = get_study_image_paths(dataset_root, images_glob, study)
        if not study_images:
            continue
        if max_images_per_study > 0:
            study_images = study_images[:max_images_per_study]

        paired.append(
            {
                "study_id": study,
                "image_paths": [str(image) for image in study_images],
                "report_text": cleaned,
            }
        )

    if not paired:
        raise RuntimeError("No valid (image, report) samples were built.")
    return paired
|
|
|
|
def split_by_study(
    samples: List[Dict[str, object]],
    val_ratio: float,
    seed: int,
) -> Tuple[List[Dict[str, object]], List[Dict[str, object]]]:
    """Split samples into train/validation sets at study granularity.

    Study ids are sorted, shuffled with a ``random.Random(seed)`` instance and
    the first ``val_count`` ids become the validation set, so the split is
    reproducible for a given seed.

    Args:
        samples: Sample dicts, each carrying a ``study_id`` key.
        val_ratio: Fraction of studies to hold out; ``<= 0`` disables the
            validation split.
        seed: Seed for the shuffle.

    Returns:
        ``(train_samples, val_samples)`` preserving the input order.
    """
    study_ids = sorted({s["study_id"] for s in samples})
    rng = random.Random(seed)
    rng.shuffle(study_ids)

    val_count = 0
    if val_ratio > 0:
        # Hold out at least one study ...
        val_count = max(1, int(len(study_ids) * val_ratio))
        # ... but never all of them: training must keep at least one study.
        val_count = min(val_count, max(0, len(study_ids) - 1))
    val_ids = set(study_ids[:val_count])

    train_samples = [s for s in samples if s["study_id"] not in val_ids]
    val_samples = [s for s in samples if s["study_id"] in val_ids]
    return train_samples, val_samples
|
|
|
|
| def _build_messages( |
| images: List[Image.Image], |
| report_text: str, |
| instruction: str, |
| system_prompt: Optional[str], |
| ) -> Dict[str, List[Dict]]: |
| messages: List[Dict] = [] |
| if system_prompt: |
| messages.append( |
| { |
| "role": "system", |
| "content": [{"type": "text", "text": system_prompt}], |
| } |
| ) |
|
|
| user_content: List[Dict] = [{"type": "text", "text": instruction}] |
| user_content.extend({"type": "image", "image": image} for image in images) |
|
|
| messages.extend( |
| [ |
| { |
| "role": "user", |
| "content": user_content, |
| }, |
| { |
| "role": "assistant", |
| "content": [{"type": "text", "text": report_text}], |
| }, |
| ] |
| ) |
|
|
| return {"messages": messages} |
|
|
|
|
| def _build_inference_messages( |
| images: List[Image.Image], |
| instruction: str, |
| system_prompt: Optional[str], |
| ) -> List[Dict]: |
| messages: List[Dict] = [] |
| if system_prompt: |
| messages.append( |
| { |
| "role": "system", |
| "content": [{"type": "text", "text": system_prompt}], |
| } |
| ) |
|
|
| user_content: List[Dict] = [{"type": "text", "text": instruction}] |
| user_content.extend({"type": "image", "image": image} for image in images) |
|
|
| messages.append( |
| { |
| "role": "user", |
| "content": user_content, |
| } |
| ) |
| return messages |
|
|
|
|
def _load_rgb_images(image_paths: List[str], split_name: str) -> List[Image.Image]:
    """Open each path and return the images converted to RGB.

    Paths that cannot be read are reported on stdout (tagged with
    ``split_name``) and skipped. If nothing could be loaded, a single black
    224x224 placeholder is returned so the caller always gets one image.
    """
    rgb_images: List[Image.Image] = []
    for path in image_paths:
        try:
            with Image.open(path) as handle:
                rgb_images.append(handle.convert("RGB"))
        except (OSError, UnidentifiedImageError, ValueError) as error:
            print(f"[{split_name}] Runtime unreadable image: {path} ({error})")

    if rgb_images:
        return rgb_images
    # Placeholder keeps the sample usable when every image failed to load.
    return [Image.new("RGB", (224, 224), color=(0, 0, 0))]
|
|
|
|
def generate_eval_report(
    model,
    tokenizer,
    image_paths: List[str],
    instruction: str,
    system_prompt: Optional[str],
    max_new_tokens: int = 256,
) -> str:
    """Generate a report for one study and return it as cleaned text.

    Failures never propagate: prompt-building, tokenization and generation
    errors are each returned as a ``<... failed: ...>`` marker string.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tokenizer: Object providing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        image_paths: Image files of the study.
        instruction: User instruction text.
        system_prompt: Optional system prompt.
        max_new_tokens: Generation budget.

    Returns:
        The generated report passed through ``clean_report_text``, or an
        error marker string.
    """
    images = _load_rgb_images(image_paths, split_name="eval")
    messages = _build_inference_messages(images, instruction, system_prompt)

    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as error:
        return f"<prompt build failed: {error}>"

    # The image keyword/nesting accepted by the processor is not known up front,
    # so each variant is tried in turn until one succeeds.
    image_kwarg_variants = (
        {"images": images},
        {"images": [images]},
        {"image": images},
        {"image": [images]},
    )
    inputs = None
    failures: List[str] = []
    for image_kwargs in image_kwarg_variants:
        try:
            inputs = tokenizer(
                text=[prompt_text],
                padding=True,
                return_tensors="pt",
                **image_kwargs,
            )
        except Exception as error:
            failures.append(str(error))
        else:
            break

    if inputs is None:
        return f"<tokenization failed: {' | '.join(failures)}>"

    # Move tensor entries to the model's device; leave anything else untouched.
    target_device = next(model.parameters()).device
    inputs = {
        name: (entry.to(target_device) if isinstance(entry, torch.Tensor) else entry)
        for name, entry in inputs.items()
    }

    outputs = None
    try:
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        # Drop the prompt tokens so only the continuation is decoded.
        prompt_length = inputs["input_ids"].shape[-1] if "input_ids" in inputs else 0
        decoded = tokenizer.batch_decode(
            outputs[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return clean_report_text(decoded[0])
    except torch.OutOfMemoryError as error:
        return f"<generation OOM: {error}>"
    except Exception as error:
        return f"<generation failed: {error}>"
    finally:
        # Release references before asking CUDA to free cached blocks.
        if outputs is not None:
            del outputs
        if inputs is not None:
            del inputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
class EvalSampleGenerationCallback(TrainerCallback):
    """Trainer callback that prints and logs sample generations at each eval.

    On every evaluation event it draws ``sample_count`` random validation
    samples, generates a report for each, prints original vs. generated text
    and appends the results to a CSV file under ``output_dir``. Any failure is
    reported and swallowed so training continues.
    """

    # Column order of the CSV log.
    _CSV_COLUMNS = (
        "global_step",
        "study_id",
        "image_count",
        "image_paths",
        "original_report",
        "generated_report",
    )

    def __init__(
        self,
        model,
        tokenizer,
        eval_samples: List[Dict[str, object]],
        instruction: str,
        system_prompt: Optional[str],
        seed: int,
        output_dir: str,
        sample_count: int,
        max_new_tokens: int,
        csv_filename: str = "eval_random_samples.csv",
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.eval_samples = eval_samples
        self.instruction = instruction
        self.system_prompt = system_prompt
        self.rng = random.Random(seed)
        self.sample_count = max(0, sample_count)
        self.max_new_tokens = max(1, max_new_tokens)
        self.output_csv_path = Path(output_dir) / csv_filename
        self.output_csv_path.parent.mkdir(parents=True, exist_ok=True)
        # Only a brand-new file gets a header, so resumed runs keep appending.
        if not self.output_csv_path.exists():
            with self.output_csv_path.open("w", newline="", encoding="utf-8") as handle:
                self._csv_writer(handle).writeheader()

    def _csv_writer(self, handle) -> csv.DictWriter:
        """Return a DictWriter bound to ``handle`` with the log's columns."""
        return csv.DictWriter(handle, fieldnames=list(self._CSV_COLUMNS))

    def _generate_row(self, sample: Dict[str, object], step: int) -> Dict[str, str]:
        """Generate a report for ``sample``, print it and return its CSV row."""
        paths = [str(path) for path in sample["image_paths"]]
        generated = generate_eval_report(
            model=self.model,
            tokenizer=self.tokenizer,
            image_paths=paths,
            instruction=self.instruction,
            system_prompt=self.system_prompt,
            max_new_tokens=self.max_new_tokens,
        )
        print(f"id: {sample['study_id']}")
        print(f"image_count: {len(paths)}")
        print(f"image_paths: {' | '.join(paths)}")
        print(f"original: {sample['report_text']}")
        print(f"generated: {generated}")
        print("-" * 80)
        return {
            "global_step": str(step),
            "study_id": str(sample["study_id"]),
            "image_count": str(len(paths)),
            "image_paths": " | ".join(paths),
            "original_report": str(sample["report_text"]),
            "generated_report": generated,
        }

    def on_evaluate(self, args, state, control, **kwargs):
        """Generate, print and log random validation samples; never raises."""
        if not self.eval_samples or self.sample_count <= 0:
            return control

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Switch the model to inference mode for generation; restored in the
            # outer ``finally`` below.
            try:
                FastVisionModel.for_inference(self.model)
            except Exception as _mode_err:
                print(f"[eval] for_inference switch failed (non-fatal): {_mode_err}")

            draw_count = min(self.sample_count, len(self.eval_samples))
            drawn = self.rng.sample(self.eval_samples, draw_count)
            step = int(state.global_step)

            print(f"\n[eval@step={step}] Random {draw_count} generation samples")
            rows: List[Dict[str, str]] = []
            for sample in drawn:
                try:
                    rows.append(self._generate_row(sample, step))
                except torch.OutOfMemoryError as sample_error:
                    print(f"[eval] Skipping one sample due to OOM: {sample_error}")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except Exception as sample_error:
                    print(f"[eval] Skipping one sample due to error: {sample_error}")
                finally:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            if rows:
                with self.output_csv_path.open("a", newline="", encoding="utf-8") as handle:
                    self._csv_writer(handle).writerows(rows)
                # NOTE(review): source indentation was lost; this message is assumed
                # to belong to the "rows were written" branch -- confirm.
                print(f"Saved eval generations to: {self.output_csv_path}")
        except Exception as eval_error:
            print(f"[eval] Callback failed but training will continue: {eval_error}")
        finally:
            # Always hand the model back in training mode, whatever happened above.
            try:
                FastVisionModel.for_training(self.model)
            except Exception as _mode_err:
                print(f"[eval] for_training switch failed (non-fatal): {_mode_err}")
            try:
                torch._dynamo.reset()
            except Exception as _dynamo_err:
                print(f"[eval] dynamo reset failed (non-fatal): {_dynamo_err}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return control
|
|
|
|
def load_image_validity_cache(cache_path: Path) -> Dict[str, bool]:
    """Read the image-readability cache from ``cache_path``.

    Returns:
        A mapping of image path to validity flag. An empty dict is returned
        when the file is missing, unreadable, not valid JSON, or does not
        contain a JSON object.
    """
    if not cache_path.exists():
        return {}

    try:
        payload = json.loads(cache_path.read_text(encoding="utf-8"))
    except (OSError, ValueError, json.JSONDecodeError):
        return {}

    if isinstance(payload, dict):
        # Coerce keys/values so callers always see str -> bool.
        return {str(path): bool(flag) for path, flag in payload.items()}
    return {}
|
|
|
|
def save_image_validity_cache(cache_path: Path, cache: Dict[str, bool]) -> None:
    """Persist the image-readability cache as JSON at ``cache_path``.

    The parent directory is created if needed. The data is written to a
    sibling ``<name>.tmp`` file first and then moved over the target, so an
    interrupted run cannot leave a half-written cache behind.

    Args:
        cache_path: Destination JSON file.
        cache: Mapping of image path to validity flag.
    """
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = cache_path.with_name(cache_path.name + ".tmp")
    temp_path.write_text(json.dumps(cache), encoding="utf-8")
    # Path.replace overwrites an existing target in a single rename step.
    temp_path.replace(cache_path)
|
|
|
|
def filter_readable_samples(
    samples: List[Dict[str, object]],
    cache_path: Path,
    split_name: str,
) -> List[Dict[str, object]]:
    """Drop unreadable images, and samples left without any image.

    Each image is checked with Pillow's ``verify()``; results are memoised in
    the JSON cache at ``cache_path`` so later runs skip already-checked files.

    Args:
        samples: Sample dicts with an ``image_paths`` list.
        cache_path: Location of the validity cache (read, updated, rewritten).
        split_name: Label used in the summary line.

    Returns:
        New sample dicts (the inputs are not modified) whose ``image_paths``
        contain only readable images. Samples with no readable image are
        omitted.
    """
    cache = load_image_validity_cache(cache_path)

    kept: List[Dict[str, object]] = []
    skipped_samples = 0
    dropped_images = 0
    newly_checked = 0

    for sample in samples:
        readable_paths: List[str] = []
        for image_path in sample["image_paths"]:
            image_path = str(image_path)
            is_valid = cache.get(image_path)

            if is_valid is None:
                newly_checked += 1
                try:
                    with Image.open(image_path) as opened_image:
                        opened_image.verify()
                    is_valid = True
                # SyntaxError: Pillow's PNG verifier reports broken chunks this way.
                except (OSError, UnidentifiedImageError, ValueError, SyntaxError):
                    is_valid = False
                cache[image_path] = is_valid

            if is_valid:
                readable_paths.append(image_path)
            else:
                dropped_images += 1

        if readable_paths:
            # Copy instead of mutating the caller's dict.
            kept.append({**sample, "image_paths": readable_paths})
        else:
            skipped_samples += 1

    save_image_validity_cache(cache_path, cache)
    print(
        f"[{split_name}] Kept {len(kept)} / {len(samples)} samples, "
        f"skipped {skipped_samples} samples without readable images, "
        f"dropped {dropped_images} unreadable images "
        f"(newly checked: {newly_checked})."
    )
    return kept
|
|
|
|
def build_hf_dataset(
    samples: List[Dict[str, object]],
) -> Dataset:
    """Create a ``Dataset`` holding only the columns needed for training.

    Each row keeps ``image_paths`` and ``report_text``; other sample keys
    (such as ``study_id``) are left out.
    """
    kept_columns = ("image_paths", "report_text")
    records = []
    for sample in samples:
        records.append({column: sample[column] for column in kept_columns})
    return Dataset.from_list(records)
|
|
|
|
def attach_lazy_vision_transform(
    dataset: Dataset,
    instruction: str,
    system_prompt: Optional[str],
    split_name: str,
) -> Dataset:
    """Install an on-access transform that turns rows into chat messages.

    Images are opened only when a row (or batch of rows) is read, via
    ``dataset.set_transform``. The transform accepts either a single example
    or a batch and returns ``{"messages": ...}`` in the matching shape.

    Returns:
        The same ``dataset`` object, with the transform attached.
    """

    def to_conversation(paths, report) -> List[Dict]:
        loaded = _load_rgb_images([str(path) for path in paths], split_name=split_name)
        built = _build_messages(
            images=loaded,
            report_text=str(report),
            instruction=instruction,
            system_prompt=system_prompt,
        )
        return built["messages"]

    def transform(examples: Dict[str, object]) -> Dict[str, List[Dict]]:
        paths_column = examples["image_paths"]
        reports_column = examples["report_text"]
        # A batch is a non-empty list whose first element is itself a list of paths.
        batched = (
            isinstance(paths_column, list)
            and bool(paths_column)
            and isinstance(paths_column[0], list)
        )
        if batched:
            return {
                "messages": [
                    to_conversation(paths, report)
                    for paths, report in zip(paths_column, reports_column)
                ]
            }
        return {"messages": to_conversation(paths_column, reports_column)}

    dataset.set_transform(transform)
    return dataset
|
|
|
|
class SafeVisionDataCollator(UnslothVisionDataCollator):
    """Collator that survives image-token count mismatches.

    When collating a batch raises a ``ValueError`` mentioning
    "Mismatch in `image` token count" (presumably caused by truncating
    multi-image sequences that exceed ``max_length`` -- TODO confirm), every
    item is re-collated on its own, the items that fail are dropped, and the
    remaining items are collated as a smaller batch. Dropped items are NOT
    replaced. A warning is printed once per distinct error text.

    Any other ``ValueError``, or a batch in which no item survives, is
    re-raised unchanged.
    """

    # Substring identifying the recoverable error.
    _MISMATCH_MARKER = "Mismatch in `image` token count"

    def __init__(self, model, tokenizer, max_seq_length: Optional[int] = None):
        super().__init__(model, tokenizer, max_seq_length=max_seq_length)
        # Error prefixes already reported, to avoid repeating the same warning.
        self._warned: set = set()

    def __call__(self, features):
        """Collate ``features``, dropping items that cause a token mismatch."""
        try:
            return super().__call__(features)
        except ValueError as exc:
            if self._MISMATCH_MARKER not in str(exc):
                raise

            good_features = self._filter_bad_samples(features)
            if not good_features:
                # Nothing left to collate: surface the original error.
                raise
            return super().__call__(good_features)

    def _filter_bad_samples(self, features):
        """Return the items of ``features`` that collate on their own."""
        good = []
        for item in features:
            try:
                super().__call__([item])
                good.append(item)
            except ValueError as exc:
                # The first 120 characters serve as the de-duplication key.
                key = str(exc)[:120]
                if key not in self._warned:
                    self._warned.add(key)
                    print(
                        f"\n[SafeVisionDataCollator] Skipping 1 sample that causes truncation "
                        f"mismatch (will not warn again for identical error):\n  {key}\n"
                    )
        return good
|
|
|
|
def print_gpu_memory_stats(prefix: str) -> None:
    """Print name, total memory and peak reserved memory of the active GPU.

    Reports on the *current* CUDA device, so the numbers match the device
    selected with ``torch.cuda.set_device`` rather than always device 0.

    Args:
        prefix: Tag placed in square brackets at the start of each line.
    """
    if not torch.cuda.is_available():
        print(f"[{prefix}] CUDA not available.")
        return

    device_index = torch.cuda.current_device()
    gpu_stats = torch.cuda.get_device_properties(device_index)
    bytes_per_gib = 1024 ** 3
    max_memory = round(gpu_stats.total_memory / bytes_per_gib, 3)
    # Peak reserved memory on the same device since start / last reset.
    used_memory = round(torch.cuda.max_memory_reserved(device_index) / bytes_per_gib, 3)
    print(f"[{prefix}] GPU = {gpu_stats.name}")
    print(f"[{prefix}] Max GPU memory = {max_memory} GB")
    print(f"[{prefix}] Reserved memory = {used_memory} GB")
|
|
|
|
# Files that must exist inside a checkpoint folder for it to count as complete.
_REQUIRED_CHECKPOINT_FILES = {"trainer_state.json"}


def find_latest_checkpoint(output_dir: Path) -> Optional[str]:
    """Return the highest-numbered complete ``checkpoint-<N>`` folder.

    Candidates are directories under ``output_dir`` named ``checkpoint-*``
    whose last dash-separated part is numeric. They are examined from the
    highest number down; folders missing a required file are reported and
    skipped.

    Returns:
        The checkpoint path as a string, or ``None`` when ``output_dir`` does
        not exist or holds no complete checkpoint.
    """
    if not output_dir.exists():
        return None

    numbered = []
    for candidate in output_dir.glob("checkpoint-*"):
        step_text = candidate.name.split("-")[-1]
        if candidate.is_dir() and step_text.isdigit():
            numbered.append((int(step_text), candidate))
    if not numbered:
        return None

    # Newest (largest step) first.
    numbered.sort(key=lambda pair: pair[0], reverse=True)

    for _step, candidate in numbered:
        missing = [
            required
            for required in _REQUIRED_CHECKPOINT_FILES
            if not (candidate / required).exists()
        ]
        if not missing:
            return str(candidate)
        print(
            f"[checkpoint] Skipping incomplete checkpoint {candidate.name} "
            f"(missing: {', '.join(missing)})"
        )

    print("[checkpoint] No valid (complete) checkpoints found.")
    return None
|
|
|
|
def _validate_cli_args(args: argparse.Namespace) -> None:
    """Reject unsupported or out-of-range CLI values with a ``ValueError``."""
    if args.no_8bit:
        raise ValueError("--no_8bit is not supported in finetune_8bit.py. Use finetune_4bit.py for 4bit training.")

    # This script is 8-bit only, whatever was passed on the command line.
    args.load_in_8bit = True

    if "4bit" in args.model_name.lower():
        raise ValueError("--model_name appears to be a 4bit model. Please provide a non-4bit model for this script.")
    if args.val_ratio < 0 or args.val_ratio >= 1:
        raise ValueError("--val_ratio must be in [0, 1).")
    if args.num_train_epochs <= 0:
        raise ValueError("--num_train_epochs must be > 0.")
    if args.max_grad_norm <= 0:
        raise ValueError("--max_grad_norm must be > 0.")
    if args.max_train_retries < 0:
        raise ValueError("--max_train_retries must be >= 0.")
    if args.retry_wait_seconds < 0:
        raise ValueError("--retry_wait_seconds must be >= 0.")
    if args.eval_sample_count < 0:
        raise ValueError("--eval_sample_count must be >= 0.")
    if args.eval_max_new_tokens <= 0:
        raise ValueError("--eval_max_new_tokens must be > 0.")
    if args.cuda_device < 0:
        raise ValueError("--cuda_device must be >= 0.")


def _configure_wandb_env(args: argparse.Namespace) -> None:
    """Export the W&B project, run name and (optional) entity via env vars."""
    os.environ["WANDB_PROJECT"] = args.wandb_project

    # Without an explicit name, derive one from the key hyper-parameters.
    run_name = args.wandb_run_name or (
        f"lr{args.learning_rate:.0e}"
        f"_ep{args.num_train_epochs:.0f}"
        f"_bs{args.batch_size}x{args.grad_accum}"
        f"_r{args.lora_r}a{args.lora_alpha}"
        f"_warm{args.warmup_steps}"
        f"_gc{args.max_grad_norm}"
    )
    os.environ["WANDB_NAME"] = run_name
    print(f"W&B run name: {run_name}")
    if args.wandb_entity:
        os.environ["WANDB_ENTITY"] = args.wandb_entity


def _discard_incomplete_checkpoint(checkpoint: str, output_dir: Path) -> None:
    """Delete ``checkpoint`` if it is incomplete AND owned by this run.

    Only ``checkpoint-*`` folders located directly inside ``output_dir`` are
    ever removed, so a path supplied through ``--resume_from_checkpoint`` that
    points somewhere else is never deleted.
    """
    import shutil

    checkpoint_path = Path(checkpoint)
    missing = [
        name for name in _REQUIRED_CHECKPOINT_FILES
        if not (checkpoint_path / name).exists()
    ]
    if not missing or not checkpoint_path.exists():
        return

    try:
        owned_by_run = (
            checkpoint_path.name.startswith("checkpoint-")
            and checkpoint_path.resolve().parent == output_dir.resolve()
        )
    except OSError:
        owned_by_run = False
    if not owned_by_run:
        print(
            f"[train] Leaving incomplete checkpoint {checkpoint_path} in place "
            f"(missing: {', '.join(missing)}); it is outside {output_dir}."
        )
        return

    print(
        f"[train] Removing incomplete checkpoint {checkpoint_path.name} "
        f"(missing: {', '.join(missing)}) before retry."
    )
    try:
        shutil.rmtree(str(checkpoint_path))
    except Exception as _rm_err:
        print(f"[train] Could not remove incomplete checkpoint: {_rm_err}")


def main():
    """Run the full 8-bit LoRA fine-tuning pipeline.

    Steps: parse and validate arguments, select the CUDA device, load the
    model and attach LoRA adapters, build and split the (image, report)
    samples, optionally verify image readability, train with automatic
    retry/resume from the latest complete checkpoint, then save the adapter
    and tokenizer to ``--output_dir``.

    Raises:
        ValueError: Invalid argument values.
        FileNotFoundError: Missing dataset root or resume checkpoint.
        RuntimeError: CUDA unavailable, or training produced no result.
    """
    args = parse_args()
    _validate_cli_args(args)

    # Best effort: this private torch knob may not exist in every version.
    try:
        torch._dynamo.config.suppress_errors = True
    except Exception:
        pass

    if args.use_wandb:
        _configure_wandb_env(args)

    print(f"Using epoch-based training for {args.num_train_epochs} epochs.")
    print(f"Using gradient clipping max_grad_norm={args.max_grad_norm}.")
    print("Torch compile is disabled for stability with 8-bit bitsandbytes kernels.")

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_root = Path(args.dataset_root)
    if not dataset_root.exists():
        raise FileNotFoundError(f"Dataset root not found: {dataset_root}")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for 8-bit training, but no CUDA device is available.")

    device_count = torch.cuda.device_count()
    if args.cuda_device >= device_count:
        raise ValueError(
            f"--cuda_device={args.cuda_device} is out of range. Visible CUDA devices: {device_count}"
        )

    torch.cuda.set_device(args.cuda_device)
    train_device_index = torch.cuda.current_device()
    # Place the whole model on the selected device.
    quantized_device_map = {"": train_device_index}
    print(f"Using CUDA device index: {train_device_index}")

    print("Loading model...")
    model, tokenizer = FastVisionModel.from_pretrained(
        args.model_name,
        load_in_4bit=False,
        load_in_8bit=True,
        use_gradient_checkpointing="unsloth",
        device_map=quantized_device_map,
    )

    model = FastVisionModel.get_peft_model(
        model,
        finetune_vision_layers=True,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        random_state=args.seed,
        use_rslora=False,
        loftq_config=None,
    )

    print("Building paired image-report samples...")
    samples = build_samples(
        dataset_root=dataset_root,
        reports_dir_name=args.reports_dir,
        images_glob=args.images_glob,
        min_report_chars=args.min_report_chars,
        max_images_per_study=args.max_images_per_study,
    )

    train_samples, val_samples = split_by_study(samples, args.val_ratio, args.seed)

    if args.max_train_samples > 0:
        train_samples = train_samples[: args.max_train_samples]
    if args.max_val_samples > 0:
        val_samples = val_samples[: args.max_val_samples]

    if args.skip_image_verification:
        print("Skipping image verification step as requested.")
    else:
        cache_path = (
            Path(args.image_validity_cache)
            if args.image_validity_cache
            else dataset_root / ".image_validity_cache.json"
        )
        print(f"Verifying image readability with cache: {cache_path}")
        train_samples = filter_readable_samples(train_samples, cache_path, split_name="train")
        if val_samples:
            val_samples = filter_readable_samples(val_samples, cache_path, split_name="val")

    print(f"Total samples: {len(samples)}")
    print(f"Train samples: {len(train_samples)}")
    print(f"Val samples:   {len(val_samples)}")

    # Treat an empty / whitespace-only system prompt as "no system prompt".
    system_prompt = args.system_prompt.strip() if args.system_prompt else ""
    if not system_prompt:
        system_prompt = None

    train_dataset = build_hf_dataset(train_samples)
    train_dataset = attach_lazy_vision_transform(train_dataset, args.instruction, system_prompt, split_name="train")
    eval_dataset = (
        attach_lazy_vision_transform(build_hf_dataset(val_samples), args.instruction, system_prompt, split_name="val")
        if val_samples
        else None
    )

    print(f"Final train dataset size: {len(train_dataset)}")
    if eval_dataset is not None:
        print(f"Final val dataset size:   {len(eval_dataset)}")

    FastVisionModel.for_training(model)

    config_kwargs = {
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.batch_size,
        "gradient_accumulation_steps": args.grad_accum,
        "max_grad_norm": args.max_grad_norm,
        "warmup_steps": args.warmup_steps,
        "learning_rate": args.learning_rate,
        "logging_steps": args.logging_steps,
        "optim": "adamw_8bit",
        "weight_decay": 0.001,
        "lr_scheduler_type": "cosine",
        "seed": args.seed,
        "output_dir": args.output_dir,
        "report_to": "wandb" if args.use_wandb else "none",
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "remove_unused_columns": False,
        "dataset_text_field": "",
        "dataset_kwargs": {"skip_prepare_dataset": True},
        "max_length": args.max_length,
        "num_train_epochs": args.num_train_epochs,
        "dataloader_num_workers": args.dataloader_num_workers,
    }

    if eval_dataset is not None:
        # Evaluate on the same cadence as checkpoint saving.
        config_kwargs.update(
            {
                "eval_strategy": "steps",
                "eval_steps": args.save_steps,
            }
        )
    else:
        config_kwargs["eval_strategy"] = "no"

    callbacks = []
    if eval_dataset is not None and val_samples:
        if args.eval_sample_count > 0:
            callbacks.append(
                EvalSampleGenerationCallback(
                    model=model,
                    tokenizer=tokenizer,
                    eval_samples=val_samples,
                    instruction=args.instruction,
                    system_prompt=system_prompt,
                    seed=args.seed,
                    output_dir=args.output_dir,
                    sample_count=args.eval_sample_count,
                    max_new_tokens=args.eval_max_new_tokens,
                )
            )
        else:
            # Honour the documented "--eval_sample_count 0 disables" contract
            # instead of silently forcing three samples.
            print("eval_sample_count=0: eval sample generation is disabled.")

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=SafeVisionDataCollator(model, tokenizer, max_seq_length=args.max_length),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=SFTConfig(**config_kwargs),
        callbacks=callbacks,
    )

    print_gpu_memory_stats("BEFORE TRAIN")

    output_dir = Path(args.output_dir)

    if args.resume_from_checkpoint:
        resume_path = Path(args.resume_from_checkpoint)
        if not resume_path.exists() or not resume_path.is_dir():
            raise FileNotFoundError(f"--resume_from_checkpoint not found or not a directory: {resume_path}")
        resume_checkpoint = str(resume_path)
    else:
        resume_checkpoint = find_latest_checkpoint(output_dir)

    trainer_stats = None
    retry_attempt = 0
    while retry_attempt <= args.max_train_retries:
        try:
            if resume_checkpoint:
                print(f"Resuming training from checkpoint: {resume_checkpoint}")
                trainer_stats = trainer.train(resume_from_checkpoint=resume_checkpoint)
            else:
                trainer_stats = trainer.train()
            break
        except Exception as error:
            retry_attempt += 1
            print("\n[train] Caught training exception. Attempting automatic recovery...")
            print(f"[train] Retry {retry_attempt} / {args.max_train_retries}")
            print(f"[train] Exception: {error}")
            traceback.print_exc()

            # Best effort: freeing cached memory must not mask the real error.
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass

            # A half-written checkpoint would fail again on resume; drop it (only
            # if it belongs to this run) and fall back to the newest complete one.
            if resume_checkpoint is not None:
                _discard_incomplete_checkpoint(resume_checkpoint, output_dir)

            # NOTE(review): source indentation was lost; the checkpoint lookup is
            # assumed to run on every failure, not only when resuming -- confirm.
            resume_checkpoint = find_latest_checkpoint(output_dir)
            if retry_attempt > args.max_train_retries:
                raise

            if args.retry_wait_seconds > 0:
                print(f"[train] Waiting {args.retry_wait_seconds}s before retry...")
                time.sleep(args.retry_wait_seconds)

    if trainer_stats is None:
        raise RuntimeError("Training did not produce stats after retry attempts.")

    print_gpu_memory_stats("AFTER TRAIN")

    print("Train metrics:")
    print(trainer_stats.metrics)

    output_dir.mkdir(parents=True, exist_ok=True)

    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"Saved LoRA adapter + tokenizer to: {output_dir}")
|
|
|
# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()