jjsprockel committed on
Commit
b782f98
·
verified ·
1 Parent(s): 73da3e4

Add Colab notebook for LUAD subtype inference (MedGemma-27B QLoRA)

notebooks/MedGemma27B_LUAD_inference.ipynb ADDED
@@ -0,0 +1,239 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0c2f74ec",
+ "metadata": {},
+ "source": [
+ "\n",
+ "# 🩺 MedGemma-27B (QLoRA) — Inference Notebook (Colab, A100 80GB)\n",
+ "This notebook loads the base **`google/medgemma-27b-it`** model and your **QLoRA adapter** **`jjsprockel/medgemma27b-luad-qlora`** to predict **lung adenocarcinoma subtypes** from an H&E image.\n",
+ "\n",
+ "**Recommended requirements on Colab Pro/Pro+:**\n",
+ "- **GPU:** A100 **80 GB** (Runtime → Change runtime type → GPU → A100; then *Reconnect*).\n",
+ "- **Python:** 3.10+\n",
+ "- **Transformers:** 4.50+ (Gemma 3 / MedGemma support)\n",
+ "\n",
+ "> Note: If the Hugging Face repo is private, log in with your token in the corresponding cell.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "99aaec3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import sys\n",
+ "import torch\n",
+ "\n",
+ "print(\"Python:\", sys.version)\n",
+ "print(\"PyTorch:\", torch.__version__)\n",
+ "print(\"CUDA available:\", torch.cuda.is_available())\n",
+ "if torch.cuda.is_available():\n",
+ " print(\"GPU name:\", torch.cuda.get_device_name(0))\n",
+ " print(\"Total VRAM (GB):\", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2))\n",
+ "\n",
+ "# Strongly suggest A100 80GB\n",
+ "if torch.cuda.is_available():\n",
+ " vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)\n",
+ " if vram_gb < 70:\n",
+ " print(\"\\n[WARNING] Detected <70 GB VRAM. 4-bit quantization is enabled, but you may still hit OOM with very large images.\")\n",
+ "else:\n",
+ " print(\"[WARNING] No GPU detected. Please switch to a GPU runtime (A100 preferred).\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "05184873",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "%%bash\n",
+ "pip -q install --upgrade pip\n",
+ "pip -q install 'transformers>=4.50.0' 'accelerate>=0.34.2' 'bitsandbytes>=0.43.3' 'peft>=0.12.0' 'huggingface_hub>=0.24.6' 'safetensors>=0.4.4' 'Pillow' 'torchvision'\n"
+ ]
+ },
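+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4c0ffee1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Optional sanity check (a minimal sketch): Gemma 3 / MedGemma support landed in\n",
+ "# transformers 4.50, and Colab can keep an older wheel loaded until the runtime\n",
+ "# is restarted, so fail fast here rather than deep inside from_pretrained.\n",
+ "import transformers\n",
+ "from packaging import version\n",
+ "\n",
+ "assert version.parse(transformers.__version__) >= version.parse(\"4.50.0\"), (\n",
+ " f\"transformers {transformers.__version__} is too old for MedGemma; restart the runtime after installing.\"\n",
+ ")\n",
+ "print(\"transformers\", transformers.__version__, \"OK\")\n"
+ ]
+ },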
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "71b17a36",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# OPTIONAL: Only needed if your repos are private.\n",
+ "# from huggingface_hub import login\n",
+ "# login() # <- paste your HF token when prompted\n",
+ "pass\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0f914903",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig\n",
+ "from peft import PeftModel\n",
+ "import torch, io, json, re, requests\n",
+ "from PIL import Image\n",
+ "from typing import Optional\n",
+ "\n",
+ "# ---- IDs ----\n",
+ "BASE_ID = \"google/medgemma-27b-it\"\n",
+ "ADAPTER_ID = \"jjsprockel/medgemma27b-luad-qlora\"\n",
+ "\n",
+ "# ---- Class list ----\n",
+ "SUBTYPES = [\"lepidic\",\"acinar\",\"papillary\",\"micropapillary\",\"solid\",\"invasive mucinous\",\"colloid\",\"fetal\",\"enteric\"]\n",
+ "\n",
+ "# ---- Quantization (4-bit) ----\n",
+ "bnb_cfg = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_use_double_quant=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.bfloat16\n",
+ ")\n",
+ "\n",
+ "# ---- Load base and adapter ----\n",
+ "print(\"Loading base model:\", BASE_ID)\n",
+ "base = AutoModelForImageTextToText.from_pretrained(\n",
+ " BASE_ID,\n",
+ " quantization_config=bnb_cfg,\n",
+ " device_map={\"\": \"cuda\"},\n",
+ " torch_dtype=torch.bfloat16,\n",
+ " low_cpu_mem_usage=True,\n",
+ ")\n",
+ "\n",
+ "print(\"Attaching adapter:\", ADAPTER_ID)\n",
+ "model = PeftModel.from_pretrained(base, ADAPTER_ID).eval()\n",
+ "processor = AutoProcessor.from_pretrained(BASE_ID)\n",
+ "\n",
+ "# ---- Prompt templates ----\n",
+ "SYSTEM_PROMPT = (\n",
+ " \"You are an expert pulmonary pathologist. Return ONLY JSON with key 'subtype' strictly from: \"\n",
+ " + \", \".join(SUBTYPES) + \".\"\n",
+ ")\n",
+ "USER_PROMPT = \"Predict the subtype for this H&E lung adenocarcinoma patch. Only JSON.\"\n",
+ "\n",
+ "def load_image_from_url(url: str) -> Image.Image:\n",
+ " r = requests.get(url, timeout=30)\n",
+ " r.raise_for_status()\n",
+ " return Image.open(io.BytesIO(r.content)).convert(\"RGB\")\n",
+ "\n",
+ "def load_image_from_path(path: str) -> Image.Image:\n",
+ " return Image.open(path).convert(\"RGB\")\n",
+ "\n",
+ "def run_inference(img: Image.Image, max_new_tokens: int = 32) -> str:\n",
+ " messages = [\n",
+ " {\"role\":\"system\",\"content\":[{\"type\":\"text\",\"text\":SYSTEM_PROMPT}]},\n",
+ " {\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":USER_PROMPT},{\"type\":\"image\",\"image\":img}]}\n",
+ " ]\n",
+ " templ = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
+ " enc = processor(text=templ, images=img, return_tensors=\"pt\")\n",
+ " inputs = {\n",
+ " \"input_ids\": enc[\"input_ids\"].to(model.device),\n",
+ " \"attention_mask\": enc[\"attention_mask\"].to(model.device),\n",
+ " \"pixel_values\": enc[\"pixel_values\"].to(model.device, dtype=torch.bfloat16),\n",
+ " }\n",
+ " with torch.inference_mode(), torch.amp.autocast(\"cuda\", dtype=torch.bfloat16):\n",
+ " out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)[0]\n",
+ " gen = out[inputs[\"input_ids\"].shape[-1]:]\n",
+ " decoded = processor.decode(gen, skip_special_tokens=True)\n",
+ " return decoded\n",
+ "\n",
+ "def try_parse_json(s: str) -> Optional[dict]:\n",
+ " # Extract a JSON-looking object if extra tokens sneak in\n",
+ " m = re.search(r'\\{.*\\}', s, flags=re.DOTALL)\n",
+ " if m:\n",
+ " try:\n",
+ " return json.loads(m.group(0))\n",
+ " except Exception:\n",
+ " return None\n",
+ " return None\n",
+ "\n",
+ "print(\"Ready ✅\")\n"
+ ]
+ },
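+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1a2b3c4d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Optional VRAM report (a minimal sketch using only torch.cuda APIs): with NF4\n",
+ "# quantization the 27B base should sit well under the A100's 80 GB; run this\n",
+ "# after loading to see how much headroom is left for generation.\n",
+ "if torch.cuda.is_available():\n",
+ " free_b, total_b = torch.cuda.mem_get_info()\n",
+ " print(f\"Allocated by tensors: {torch.cuda.memory_allocated() / 1024**3:.1f} GB\")\n",
+ " print(f\"Reserved by PyTorch: {torch.cuda.memory_reserved() / 1024**3:.1f} GB\")\n",
+ " print(f\"Free / total on GPU: {free_b / 1024**3:.1f} / {total_b / 1024**3:.1f} GB\")\n"
+ ]
+ },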
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ddd98a70",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# ===== Option A: Load image from URL =====\n",
+ "IMAGE_URL = \"\" # <-- Paste a direct image URL here (e.g. a PNG/JPG of H&E patch)\n",
+ "img = None\n",
+ "if IMAGE_URL:\n",
+ " img = load_image_from_url(IMAGE_URL)\n",
+ " display(img)\n",
+ "\n",
+ "# ===== Option B: Upload from your computer (Colab UI) =====\n",
+ "if img is None:\n",
+ " try:\n",
+ " from google.colab import files # type: ignore\n",
+ " up = files.upload()\n",
+ " assert len(up) > 0, \"No file was uploaded.\"\n",
+ " fname = list(up.keys())[0]\n",
+ " img = load_image_from_path(fname)\n",
+ " display(img)\n",
+ " except Exception as e:\n",
+ " raise SystemExit(f\"Please provide a valid IMAGE_URL or upload an image. Error: {e}\")\n",
+ "\n",
+ "# ---- Run inference ----\n",
+ "raw = run_inference(img, max_new_tokens=32)\n",
+ "print(\"\\nRaw model output:\")\n",
+ "print(raw)\n",
+ "\n",
+ "maybe = try_parse_json(raw)\n",
+ "if maybe and isinstance(maybe, dict) and \"subtype\" in maybe:\n",
+ " print(\"\\nParsed JSON:\")\n",
+ " print(json.dumps(maybe, indent=2))\n",
+ " print(\"\\nPredicted subtype:\", maybe.get(\"subtype\"))\n",
+ "else:\n",
+ " print(\"\\n[WARNING] Could not parse a clean JSON payload. Review the raw output above.\")\n"
+ ]
+ },
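+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9e8d7c6b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Optional post-processing (a minimal sketch; `normalize_subtype` is a helper\n",
+ "# added here, not part of the original notebook): map the model's free-text\n",
+ "# answer onto the closed SUBTYPES list via case-insensitive substring matching,\n",
+ "# so downstream code always gets a canonical label (or None).\n",
+ "def normalize_subtype(text: str) -> Optional[str]:\n",
+ " t = text.lower()\n",
+ " # Check longer names first so e.g. 'micropapillary' wins over 'papillary'.\n",
+ " for s in sorted(SUBTYPES, key=len, reverse=True):\n",
+ " if s in t:\n",
+ " return s\n",
+ " return None\n",
+ "\n",
+ "label = normalize_subtype((maybe or {}).get(\"subtype\", \"\") or raw)\n",
+ "print(\"Canonical subtype:\", label)\n"
+ ]
+ },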
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "57e74a19",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# (Optional) Batch inference from a list of URLs.\n",
+ "URLS = [\n",
+ " # \"https://example.org/patch1.jpg\",\n",
+ " # \"https://example.org/patch2.png\",\n",
+ "]\n",
+ "\n",
+ "results = []\n",
+ "for url in URLS:\n",
+ " try:\n",
+ " im = load_image_from_url(url)\n",
+ " out = run_inference(im)\n",
+ " parsed = try_parse_json(out) or {\"raw\": out}\n",
+ " results.append({\"url\": url, **parsed})\n",
+ " print(f\"[OK] {url} ->\", parsed)\n",
+ " except Exception as e:\n",
+ " print(f\"[ERROR] {url}: {e}\")\n",
+ "\n",
+ "# If you want to save results to JSON:\n",
+ "# import json, time\n",
+ "# ts = int(time.time())\n",
+ "# with open(f\"batch_results_{ts}.json\", \"w\") as f:\n",
+ "# json.dump(results, f, indent=2)\n"
+ ]
+ }
+ ],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }