Shree2604 commited on
Commit
361b4d2
Β·
verified Β·
1 Parent(s): f0f2c4a

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +402 -341
server.py CHANGED
@@ -2,6 +2,7 @@ import io
2
  import torch
3
  import torch.nn as nn
4
  import timm
 
5
  import traceback
6
  import os
7
  from PIL import Image
@@ -12,7 +13,7 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ─────────────────────────────────────────────────────────────────────────────
15
- # CONFIGURATION - Matching Colab Notebook Exactly
16
  # ─────────────────────────────────────────────────────────────────────────────
17
  CONFIG = {
18
  'coatnet_model': 'coatnet_1_rw_224',
@@ -20,8 +21,6 @@ CONFIG = {
20
  'img_emb_dim': 768,
21
  'train_last_stages': 2,
22
  'image_size': 224,
23
- 'max_length': 100,
24
- 'num_beams': 4,
25
  }
26
 
27
  # ─────────────────────────────────────────────────────────────────────────────
@@ -31,17 +30,18 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  print(f"πŸ–₯️ Using device: {device}")
32
 
33
  # ─────────────────────────────────────────────────────────────────────────────
34
- # LOAD TOKENIZER - Matching Colab
35
  # ─────────────────────────────────────────────────────────────────────────────
 
36
  print("\n" + "="*80)
37
- print("LOADING TOKENIZER")
38
  print("="*80)
 
 
39
  tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
40
  print(f"βœ“ Loaded tokenizer: {CONFIG['t5_model']}")
41
 
42
- # ─────────────────────────────────────────────────────────────────────────────
43
- # IMAGE TRANSFORM - Matching Colab Exactly
44
- # ─────────────────────────────────────────────────────────────────────────────
45
  transform = transforms.Compose([
46
  transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
  transforms.ToTensor(),
@@ -52,429 +52,490 @@ transform = transforms.Compose([
52
  ])
53
  print(f"βœ“ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
 
 
 
 
 
 
55
  # ─────────────────────────────────────────────────────────────────────────────
56
- # ARCHITECTURE 1: CoAtNetEncoder - Exactly from Colab SECTION 6
 
57
  # ─────────────────────────────────────────────────────────────────────────────
58
  class CoAtNetEncoder(nn.Module):
59
- def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
60
  super().__init__()
61
- self.encoder = timm.create_model(
62
- model_name,
63
- pretrained=pretrained,
64
- num_classes=0,
65
- global_pool="avg"
66
- )
67
 
68
- # Freeze all parameters
69
- for p in self.encoder.parameters():
70
- p.requires_grad = False
 
 
 
71
 
72
- # Unfreeze last stages
73
- if hasattr(self.encoder, "stages") and train_last_stages is not None:
74
- stages = self.encoder.stages
75
- for stage in stages[-train_last_stages:]:
76
- for p in stage.parameters():
77
- p.requires_grad = True
 
 
 
78
 
79
  def forward(self, x):
80
- return self.encoder(x)
 
 
 
81
 
82
 
83
  # ─────────────────────────────────────────────────────────────────────────────
84
- # ARCHITECTURE 2: VisionT5Model - Exactly from Colab SECTION 6
 
 
85
  # ─────────────────────────────────────────────────────────────────────────────
86
- class VisionT5Model(nn.Module):
87
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
88
  super().__init__()
89
-
90
- # Vision encoder (CoAtNet)
91
  self.img_encoder = img_encoder
92
-
93
- # Text decoder (T5)
94
  self.t5 = T5ForConditionalGeneration.from_pretrained(txt_model_name)
95
-
96
- # Projection layer to match image features with T5 d_model
97
  self.proj = nn.Linear(img_emb_dim, self.t5.config.d_model)
98
 
99
- # Freeze shared T5 embeddings for faster and stable training
100
  for p in self.t5.shared.parameters():
101
  p.requires_grad = False
102
 
103
- def forward(self, pixel_values, input_ids, attention_mask, labels=None):
104
- # Extract image features
105
- img_feats = self.img_encoder(pixel_values)
 
 
 
 
 
 
 
 
 
106
 
107
- # Project image features to T5 embedding space
108
- img_feats = self.proj(img_feats)
 
109
 
110
- # Add sequence dimension
111
- encoder_hidden_states = img_feats.unsqueeze(1)
 
 
 
 
 
 
 
 
 
112
 
113
- # Run T5 encoder using image embeddings
114
- encoder_outputs = self.t5.encoder(
115
- inputs_embeds=encoder_hidden_states
116
- )
 
 
 
 
 
 
117
 
118
- # Run T5 decoder and compute loss
119
- outputs = self.t5(
120
- encoder_outputs=encoder_outputs,
121
- attention_mask=torch.ones(
122
- encoder_hidden_states.size()[:2], device=device
123
- ),
124
- input_ids=input_ids,
125
- labels=labels,
126
- )
127
- return outputs
128
-
129
- def generate_reports(self, pixel_values, max_length=100, num_beams=4):
130
- """
131
- Generate reports - EXACTLY matching Colab SECTION 6
132
- """
133
- # Extract and project image features
134
- img_feats = self.img_encoder(pixel_values)
135
- img_feats = self.proj(img_feats)
136
- encoder_hidden_states = img_feats.unsqueeze(1)
137
-
138
- # Encode image features
139
- encoder_outputs = self.t5.encoder(
140
- inputs_embeds=encoder_hidden_states
141
- )
142
 
143
- # Generate report using beam search - EXACT parameters from Colab
144
- generated_ids = self.t5.generate(
145
- encoder_outputs=encoder_outputs,
146
- attention_mask=torch.ones(
147
- encoder_hidden_states.size()[:2], device=device
148
- ),
149
- max_length=max_length,
150
- num_beams=num_beams,
151
- early_stopping=True
152
- )
 
 
153
 
154
- return generated_ids
 
 
 
 
155
 
 
 
156
 
157
- print("βœ“ Model architecture classes defined")
 
 
 
158
 
159
- # ─────────────────────────────────────────────────────────────────────────────
160
- # MODEL LOADING FUNCTION - Exactly from Colab SECTION 8
161
- # ─────────────────────────────────────────────────────────────────────────────
162
- def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
163
- """
164
- Load VisionT5Model from checkpoint - EXACT implementation from Colab
165
- """
166
- print(f"\nLoading {model_name} model...")
167
- print(f" Checkpoint: {checkpoint_path}")
 
168
 
169
- try:
170
- # Create image encoder
171
- print(f" Creating CoAtNet encoder: {config['coatnet_model']}")
172
- img_encoder = CoAtNetEncoder(
173
- model_name=config['coatnet_model'],
174
- pretrained=False, # Weights will come from checkpoint
175
- train_last_stages=config['train_last_stages']
176
- )
 
 
177
 
178
- # Create full model
179
- print(f" Creating VisionT5 model with T5: {config['t5_model']}")
180
- model = VisionT5Model(
181
- img_encoder=img_encoder,
182
- txt_model_name=config['t5_model'],
183
- img_emb_dim=config['img_emb_dim']
184
- )
185
 
186
- # Load checkpoint
187
- print(f" Loading checkpoint weights...")
188
- checkpoint = torch.load(checkpoint_path, map_location=device)
189
-
190
- # Handle different checkpoint formats
191
- if isinstance(checkpoint, dict):
192
- if 'model_state_dict' in checkpoint:
193
- state_dict = checkpoint['model_state_dict']
194
- print(f" Found 'model_state_dict' in checkpoint")
195
- elif 'state_dict' in checkpoint:
196
- state_dict = checkpoint['state_dict']
197
- print(f" Found 'state_dict' in checkpoint")
198
- elif 'model' in checkpoint:
199
- state_dict = checkpoint['model']
200
- print(f" Found 'model' in checkpoint")
201
- else:
202
- # Assume checkpoint is the state dict
203
- state_dict = checkpoint
204
- print(f" Using checkpoint as state_dict directly")
205
-
206
- # Print additional checkpoint info if available
207
- if 'epoch' in checkpoint:
208
- print(f" Checkpoint epoch: {checkpoint['epoch']}")
209
- if 'loss' in checkpoint:
210
- print(f" Checkpoint loss: {checkpoint['loss']:.4f}")
211
- else:
212
- state_dict = checkpoint
213
- print(f" Checkpoint is a state_dict")
214
-
215
- # Load state dict
216
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
217
-
218
- if missing_keys:
219
- print(f" ⚠️ Missing keys: {len(missing_keys)}")
220
- if len(missing_keys) <= 5:
221
- for key in missing_keys:
222
- print(f" - {key}")
223
-
224
- if unexpected_keys:
225
- print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")
226
- if len(unexpected_keys) <= 5:
227
- for key in unexpected_keys:
228
- print(f" - {key}")
229
-
230
- # Move to device and set to eval mode
231
- model = model.to(device)
232
- model.eval()
233
-
234
- print(f"βœ“ {model_name} model loaded successfully!")
235
- return model
236
 
237
- except Exception as e:
238
- print(f"❌ Error loading {model_name} model: {str(e)}")
239
- import traceback
240
- traceback.print_exc()
241
- raise
 
 
 
242
 
243
 
244
  # ─────────────────────────────────────────────────────────────────────────────
245
- # INFERENCE FUNCTION - Exactly from Colab SECTION 9
 
246
  # ─────────────────────────────────────────────────────────────────────────────
247
- def generate_report(
248
- image_path: str,
249
- model: VisionT5Model,
250
- config: dict
251
- ) -> str:
252
  """
253
- Generate medical report from X-ray image - EXACT implementation from Colab
 
 
 
 
 
 
 
 
 
 
254
  """
255
- try:
256
- # Preprocess image
257
- image = Image.open(image_path).convert('RGB')
258
- pixel_values = transform(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
- # Generate report - using EXACT parameters from Colab
261
- with torch.no_grad():
262
- generated_ids = model.generate_reports(
263
- pixel_values,
264
- max_length=config['max_length'],
265
- num_beams=config['num_beams']
 
 
 
 
 
 
 
 
 
 
 
 
266
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
- # Decode
269
- report = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
270
 
271
- return report.strip()
 
 
 
 
 
 
 
 
 
272
 
273
- except Exception as e:
274
- print(f"Error generating report for {image_path}: {str(e)}")
275
- return ""
276
 
277
 
278
  # ─────────────────────────────────────────────────────────────────────────────
279
- # LOAD MODELS FROM HUGGINGFACE
280
  # ─────────────────────────────────────────────────────────────────────────────
281
- print("\n" + "="*80)
282
- print("LOADING MODELS FROM HUGGINGFACE")
283
- print("="*80)
284
-
285
- # Download model files from Hugging Face
286
- try:
287
- SFT_MODEL_PATH = hf_hub_download(
288
- repo_id="vinaykumarhs2020/RLHF_radiology_model",
289
- filename="best_model.pt"
290
- )
291
- PPO_MODEL_PATH = hf_hub_download(
292
- repo_id="vinaykumarhs2020/RLHF_radiology_model",
293
- filename="rlhf_model.pt"
 
 
 
 
 
 
 
 
 
 
 
 
294
  )
295
- print(f"βœ“ Downloaded SFT model: {SFT_MODEL_PATH}")
296
- print(f"βœ“ Downloaded PPO model: {PPO_MODEL_PATH}")
297
- except Exception as e:
298
- print(f"❌ Error downloading models: {e}")
299
- # Fallback to local paths if downloads fail
300
- SFT_MODEL_PATH = "/content/best_model.pt"
301
- PPO_MODEL_PATH = "/content/rlhf_model.pt"
302
- print(f"⚠️ Using local paths instead")
303
-
304
- # Load both models
305
- print("\n" + "="*80)
306
- print("LOADING MODELS")
307
- print("="*80)
308
-
309
- sft_model = load_model_from_checkpoint(
310
- SFT_MODEL_PATH,
311
- "SFT",
312
- CONFIG
313
- )
314
-
315
- ppo_model = load_model_from_checkpoint(
316
- PPO_MODEL_PATH,
317
- "PPO",
318
- CONFIG
319
- )
320
 
321
- print("\nβœ“ Both models loaded successfully!")
322
 
323
  # ─────────────────────────────────────────────────────────────────────────────
324
  # FASTAPI APP
325
  # ─────────────────────────────────────────────────────────────────────────────
326
- app = FastAPI(title="Medical Report Generation - Matching Colab")
327
 
328
  app.add_middleware(
329
  CORSMiddleware,
330
- allow_origins=["*"],
331
  allow_methods=["*"],
332
  allow_headers=["*"],
333
  )
334
 
335
 
336
- def preprocess_bytes(file_bytes: bytes) -> torch.Tensor:
337
- """Preprocess image bytes for inference"""
338
- img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
339
- return transform(img).unsqueeze(0).to(device)
340
-
341
-
342
  @app.get("/health")
343
  def health():
344
- return {
345
- "status": "ok",
346
- "device": str(device),
347
- "models_loaded": True,
348
- "config": CONFIG
349
- }
350
 
351
 
352
  @app.post("/sft")
353
  async def sft_inference(file: UploadFile = File(...)):
354
- """
355
- SFT model inference - EXACTLY matching Colab behavior
356
- """
357
  try:
358
- # Preprocess image
359
- tensor = preprocess_bytes(await file.read())
360
-
361
- # Generate report using EXACT Colab parameters
362
- with torch.no_grad():
363
- generated_ids = sft_model.generate_reports(
364
- tensor,
365
- max_length=CONFIG['max_length'],
366
- num_beams=CONFIG['num_beams']
367
- )
368
-
369
- # Decode - EXACTLY as Colab does
370
- report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
371
-
372
  print(f"[SFT] Generated: {report}")
373
-
374
- # Return FULL report without truncation
375
- return {"report": report, "model": "SFT", "config_used": CONFIG}
376
-
377
  except Exception as e:
378
  traceback.print_exc()
379
- return {"report": f"ERROR: {str(e)}", "model": "SFT"}
380
 
381
 
382
- @app.post("/ppo")
383
- async def ppo_inference(file: UploadFile = File(...)):
384
- """
385
- PPO model inference - EXACTLY matching Colab behavior
386
- """
387
  try:
388
- # Preprocess image
389
- tensor = preprocess_bytes(await file.read())
390
-
391
- # Generate report using EXACT Colab parameters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  with torch.no_grad():
393
- generated_ids = ppo_model.generate_reports(
394
- tensor,
395
- max_length=CONFIG['max_length'],
396
- num_beams=CONFIG['num_beams']
397
- )
398
-
399
- # Decode - EXACTLY as Colab does
400
- report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
401
-
402
- print(f"[PPO] Generated: {report}")
403
-
404
- # Return FULL report without truncation
405
- return {"report": report, "model": "PPO", "config_used": CONFIG}
406
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  except Exception as e:
408
  traceback.print_exc()
409
- return {"report": f"ERROR: {str(e)}", "model": "PPO"}
410
 
411
 
412
- @app.post("/compare")
413
- async def compare_models(file: UploadFile = File(...)):
414
- """
415
- Generate reports from both models for comparison
416
- """
417
  try:
418
- file_bytes = await file.read()
419
- tensor = preprocess_bytes(file_bytes)
420
-
421
- # SFT Generation
422
- with torch.no_grad():
423
- sft_ids = sft_model.generate_reports(
424
- tensor,
425
- max_length=CONFIG['max_length'],
426
- num_beams=CONFIG['num_beams']
427
- )
428
- sft_report = tokenizer.decode(sft_ids[0], skip_special_tokens=True).strip()
429
-
430
- # PPO Generation
431
- with torch.no_grad():
432
- ppo_ids = ppo_model.generate_reports(
433
- tensor,
434
- max_length=CONFIG['max_length'],
435
- num_beams=CONFIG['num_beams']
436
- )
437
- ppo_report = tokenizer.decode(ppo_ids[0], skip_special_tokens=True).strip()
438
-
439
- print(f"[COMPARE] SFT: {sft_report}")
440
- print(f"[COMPARE] PPO: {ppo_report}")
441
-
442
- return {
443
- "sft_report": sft_report,
444
- "ppo_report": ppo_report,
445
- "config_used": CONFIG
446
- }
447
-
448
  except Exception as e:
449
  traceback.print_exc()
450
- return {
451
- "sft_report": f"ERROR: {str(e)}",
452
- "ppo_report": f"ERROR: {str(e)}"
453
- }
454
-
455
-
456
- @app.get("/debug_config")
457
- def debug_config():
458
- """Debug endpoint to check configuration"""
459
- return {
460
- "config": CONFIG,
461
- "device": str(device),
462
- "tokenizer": CONFIG['t5_model'],
463
- "image_size": CONFIG['image_size'],
464
- "max_length": CONFIG['max_length'],
465
- "num_beams": CONFIG['num_beams'],
466
- "models_loaded": {
467
- "sft": sft_model is not None,
468
- "ppo": ppo_model is not None
469
- }
470
- }
471
 
472
 
473
  # ─────────────────────────────────────────────────────────────────────────────
474
- # STATIC FILE SERVING
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  # ─────────────────────────────────────────────────────────────────────────────
476
  from fastapi.staticfiles import StaticFiles
 
477
 
 
478
  if os.path.exists("build"):
479
  app.mount("/", StaticFiles(directory="build", html=True), name="static")
480
  print("βœ… React app mounted at /")
 
2
  import torch
3
  import torch.nn as nn
4
  import timm
5
+ import pickle
6
  import traceback
7
  import os
8
  from PIL import Image
 
13
  from huggingface_hub import hf_hub_download
14
 
15
  # ─────────────────────────────────────────────────────────────────────────────
16
+ # CONFIGURATION
17
  # ─────────────────────────────────────────────────────────────────────────────
18
  CONFIG = {
19
  'coatnet_model': 'coatnet_1_rw_224',
 
21
  'img_emb_dim': 768,
22
  'train_last_stages': 2,
23
  'image_size': 224,
 
 
24
  }
25
 
26
  # ─────────────────────────────────────────────────────────────────────────────
 
30
  print(f"πŸ–₯️ Using device: {device}")
31
 
32
  # ─────────────────────────────────────────────────────────────────────────────
33
+ # SECTION 7: Load Tokenizer and Image Transform
34
  # ─────────────────────────────────────────────────────────────────────────────
35
+
36
  print("\n" + "="*80)
37
+ print("LOADING TOKENIZER AND IMAGE TRANSFORM")
38
  print("="*80)
39
+
40
+ # Load tokenizer
41
  tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
42
  print(f"βœ“ Loaded tokenizer: {CONFIG['t5_model']}")
43
 
44
+ # Define image transform
 
 
45
  transform = transforms.Compose([
46
  transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
  transforms.ToTensor(),
 
52
  ])
53
  print(f"βœ“ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
 
55
+ def preprocess_image(image_path: str) -> torch.Tensor:
56
+ """Load and preprocess image."""
57
+ image = Image.open(image_path).convert('RGB')
58
+ return transform(image)
59
+
60
  # ─────────────────────────────────────────────────────────────────────────────
61
+ # ARCHITECTURE 1 β€” CoAtNet Encoder (shared by all three models)
62
+ # Matches BOTH notebooks exactly.
63
  # ─────────────────────────────────────────────────────────────────────────────
64
  class CoAtNetEncoder(nn.Module):
65
+ def __init__(self, model_name=None, pretrained=False, train_last_stages=None):
66
  super().__init__()
67
+ # Use CONFIG defaults if not specified
68
+ model_name = model_name or CONFIG['coatnet_model']
69
+ train_last_stages = train_last_stages or CONFIG['train_last_stages']
70
+
71
+ # pretrained=False at inference time β€” weights come from .pt file
72
+ self.backbone = timm.create_model(model_name, pretrained=pretrained)
73
 
74
+ for name, param in self.backbone.named_parameters():
75
+ param.requires_grad = False
76
+ for i in range(5 - train_last_stages, 5):
77
+ if f"stages.{i}" in name:
78
+ param.requires_grad = True
79
+ break
80
 
81
+ # Detect feature_dim dynamically (same as RM/PPO notebook Cell 4)
82
+ with torch.no_grad():
83
+ dummy = torch.randn(1, 3, 224, 224)
84
+ features = self.backbone.forward_features(dummy)
85
+ if len(features.shape) == 4:
86
+ features = features.mean(dim=[2, 3])
87
+ self.feature_dim = features.shape[-1]
88
+
89
+ print(f" CoAtNetEncoder feature_dim = {self.feature_dim}")
90
 
91
  def forward(self, x):
92
+ features = self.backbone.forward_features(x)
93
+ if len(features.shape) == 4:
94
+ features = features.mean(dim=[2, 3])
95
+ return features
96
 
97
 
98
  # ─────────────────────────────────────────────────────────────────────────────
99
+ # ARCHITECTURE 2 β€” SFT VisionT5Model
100
+ # BUG FIX: Uses self.t5 and self.proj β€” exactly matching best_model.pt keys
101
+ # from SFT notebook Cell 33. Do NOT rename these to txt_model/img_proj.
102
  # ─────────────────────────────────────────────────────────────────────────────
103
+ class SFTVisionT5Model(nn.Module):
104
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
105
  super().__init__()
 
 
106
  self.img_encoder = img_encoder
107
+ # ← self.t5 (NOT self.txt_model β€” must match saved keys)
 
108
  self.t5 = T5ForConditionalGeneration.from_pretrained(txt_model_name)
109
+ # ← self.proj (NOT self.img_proj β€” must match saved keys)
 
110
  self.proj = nn.Linear(img_emb_dim, self.t5.config.d_model)
111
 
 
112
  for p in self.t5.shared.parameters():
113
  p.requires_grad = False
114
 
115
+ def generate_reports(self, pixel_values, max_length=100):
116
+ self.eval()
117
+ with torch.no_grad():
118
+ # Extract + project image features
119
+ img_feats = self.img_encoder(pixel_values) # [B, feature_dim]
120
+ img_feats = self.proj(img_feats) # [B, d_model]
121
+ encoder_hidden_states = img_feats.unsqueeze(1) # [B, 1, d_model]
122
+
123
+ # Encode
124
+ encoder_outputs = self.t5.encoder(
125
+ inputs_embeds=encoder_hidden_states
126
+ )
127
 
128
+ attn = torch.ones(
129
+ encoder_hidden_states.size()[:2], device=pixel_values.device
130
+ )
131
 
132
+ # BUG FIX 3: repetition_penalty + no_repeat_ngram_size breaks
133
+ # the "Projection: Projection: Projection:" loop
134
+ generated_ids = self.t5.generate(
135
+ encoder_outputs=encoder_outputs,
136
+ attention_mask=attn,
137
+ max_length=max_length,
138
+ num_beams=4,
139
+ early_stopping=True,
140
+ no_repeat_ngram_size=3,
141
+ repetition_penalty=1.3,
142
+ )
143
 
144
+ reports = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
145
+ # Strip any leading "Projection: X." prefix that leaked from training data
146
+ cleaned = []
147
+ for r in reports:
148
+ if r.lower().startswith("projection:"):
149
+ # Remove the first "Projection: X." segment
150
+ parts = r.split(".", 1)
151
+ r = parts[1].strip() if len(parts) > 1 else r
152
+ cleaned.append(r)
153
+ return cleaned
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ # ─────────────────────────────────────────────────────────────────────────────
157
+ # ARCHITECTURE 3 β€” PPO VisionT5Model
158
+ # Uses self.txt_model and self.img_proj β€” matching RM/PPO notebook Cell 4.
159
+ # ─────────────────────────────────────────────────────────────────────────────
160
+ class PPOVisionT5Model(nn.Module):
161
+ def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
162
+ super().__init__()
163
+ self.img_encoder = img_encoder
164
+ # ← self.txt_model (matches PPO notebook Cell 4)
165
+ self.txt_model = T5ForConditionalGeneration.from_pretrained(txt_model_name)
166
+ # ← self.img_proj (matches PPO notebook Cell 4)
167
+ self.img_proj = nn.Linear(img_emb_dim, self.txt_model.config.d_model)
168
 
169
+ def generate_reports(self, images, max_length=128):
170
+ self.eval()
171
+ with torch.no_grad():
172
+ img_features = self.img_encoder(images) # [B, feature_dim]
173
+ img_emb = self.img_proj(img_features).unsqueeze(1) # [B, 1, d_model]
174
 
175
+ batch_size = images.size(0)
176
+ img_attn = torch.ones(batch_size, 1, device=images.device)
177
 
178
+ encoder_outputs = self.txt_model.encoder(
179
+ inputs_embeds=img_emb,
180
+ attention_mask=img_attn
181
+ )
182
 
183
+ # BUG FIX 3: same repetition guards as SFT
184
+ generated = self.txt_model.generate(
185
+ encoder_outputs=encoder_outputs,
186
+ attention_mask=img_attn,
187
+ max_length=max_length,
188
+ num_beams=4,
189
+ early_stopping=True,
190
+ no_repeat_ngram_size=3,
191
+ repetition_penalty=1.3,
192
+ )
193
 
194
+ reports = tokenizer.batch_decode(generated, skip_special_tokens=True)
195
+ # Strip any leading "Projection: X." prefix that leaked from training data
196
+ cleaned = []
197
+ for r in reports:
198
+ if r.lower().startswith("projection:"):
199
+ # Remove the first "Projection: X." segment
200
+ parts = r.split(".", 1)
201
+ r = parts[1].strip() if len(parts) > 1 else r
202
+ cleaned.append(r)
203
+ return cleaned
204
 
 
 
 
 
 
 
 
205
 
206
+ # ─────────────────────────────────────────────────────────────────────────────
207
+ # ARCHITECTURE 4 β€” Reward Model
208
+ # Matches RM/PPO notebook Cell 5 exactly.
209
+ # ─────────────────────────────────────────────────────────────────────────────
210
+ class RewardModel(nn.Module):
211
+ def __init__(self, img_encoder, txt_model_name="t5-small"):
212
+ super().__init__()
213
+ self.img_encoder = img_encoder
214
+ self.txt_encoder = T5ForConditionalGeneration.from_pretrained(txt_model_name).encoder
215
+ img_dim = img_encoder.feature_dim
216
+ txt_dim = self.txt_encoder.config.d_model
217
+ self.img_proj = nn.Linear(img_dim, 512)
218
+ self.txt_proj = nn.Linear(txt_dim, 512)
219
+ self.reward_head = nn.Sequential(
220
+ nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.1),
221
+ nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.1),
222
+ nn.Linear(256, 1)
223
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ def forward(self, images, input_ids, attention_mask):
226
+ img_features = self.img_encoder(images)
227
+ img_emb = self.img_proj(img_features)
228
+ txt_outputs = self.txt_encoder(input_ids=input_ids, attention_mask=attention_mask)
229
+ txt_emb = txt_outputs.last_hidden_state.mean(dim=1)
230
+ txt_emb = self.txt_proj(txt_emb)
231
+ combined = torch.cat([img_emb, txt_emb], dim=1)
232
+ return self.reward_head(combined).squeeze(-1)
233
 
234
 
235
  # ─────────────────────────────────────────────────────────────────────────────
236
+ # MODEL LOADER β€” handles both .pt (state_dict) and .pkl (full model)
237
+ # Prints a key-match diagnostic so you can see exactly what loaded.
238
  # ─────────────────────────────────────────────────────────────────────────────
239
+ def remap_keys(raw_sd: dict, label: str) -> dict:
 
 
 
 
240
  """
241
+ Remap state_dict keys to match current model attribute names.
242
+
243
+ Known mismatches discovered from diagnostic output:
244
+ SFT notebook used:
245
+ img_encoder.encoder.* β†’ we use img_encoder.backbone.*
246
+ t5.* β†’ we use t5.* (already correct for SFTVisionT5Model)
247
+ proj.* β†’ we use proj.* (already correct for SFTVisionT5Model)
248
+ PPO/RM notebooks used:
249
+ img_encoder.backbone.* β†’ already correct βœ…
250
+ txt_model.* β†’ already correct βœ…
251
+ img_proj.* β†’ already correct βœ…
252
  """
253
+ remapped = {}
254
+ changed = 0
255
+ for k, v in raw_sd.items():
256
+ new_k = k
257
+ # SFT encoder used self.encoder, our CoAtNetEncoder uses self.backbone
258
+ if "img_encoder.encoder." in new_k:
259
+ new_k = new_k.replace("img_encoder.encoder.", "img_encoder.backbone.")
260
+ changed += 1
261
+ remapped[new_k] = v
262
+ if changed:
263
+ print(f" πŸ”§ Remapped {changed} keys: img_encoder.encoder.* β†’ img_encoder.backbone.*")
264
+ return remapped
265
+
266
+
267
+ def load_model(path: str, model_obj: nn.Module, label: str) -> nn.Module:
268
+ print(f"\nπŸ“‚ Loading {label} from: {path}")
269
+
270
+ if path.endswith(".pkl"):
271
+ with open(path, "rb") as f:
272
+ loaded = pickle.load(f)
273
+ print(f" βœ… Loaded full pickle object: {type(loaded)}")
274
+ return loaded.to(device)
275
+
276
+ # .pt state_dict
277
+ raw_sd = torch.load(path, map_location=device)
278
+
279
+ # Print first 5 saved keys for diagnosis
280
+ saved_keys = list(raw_sd.keys())
281
+ print(f" Saved keys (first 5): {saved_keys[:5]}")
282
+ model_keys = list(model_obj.state_dict().keys())
283
+ print(f" Model keys (first 5): {model_keys[:5]}")
284
+
285
+ # Remap any mismatched key prefixes
286
+ raw_sd = remap_keys(raw_sd, label)
287
+
288
+ result = model_obj.load_state_dict(raw_sd, strict=False)
289
+
290
+ # Ignore known-safe missing keys:
291
+ # head.fc.* - classification head, intentionally removed (num_classes=0)
292
+ # num_batches_tracked - BatchNorm counter, not a learned weight
293
+ SAFE_MISSING = ("num_batches_tracked", "head.fc.")
294
+ missing = [k for k in result.missing_keys if not any(s in k for s in SAFE_MISSING)]
295
+ unexpected = [k for k in result.unexpected_keys if "num_batches_tracked" not in k]
296
+
297
+ if missing:
298
+ print(f" Missing keys: {missing[:5]}{'...' if len(missing)>5 else ''}")
299
+ print(f" WARNING: {len(missing)} missing keys - weights NOT loaded for those layers!")
300
+ if unexpected:
301
+ print(f" Unexpected keys: {unexpected[:5]}{'...' if len(unexpected)>5 else ''}")
302
+ if not missing and not unexpected:
303
+ print(f" OK: All keys matched perfectly!")
304
+
305
+ return model_obj.to(device)
306
 
307
+
308
+ # ─────────────────────────────────────────────────────────────────────────────
309
+ # LOAD ALL THREE MODELS FROM HUGGING FACE HUB
310
+ # Models are downloaded from Shree2604/BioStack repository
311
+ # ─────────────────────────────────────────────────────────────────────────────
312
+ def download_model_from_hf(model_filename: str, local_path: str = "models/") -> str:
313
+ """Download model from Hugging Face Hub if not exists locally"""
314
+ os.makedirs(local_path, exist_ok=True)
315
+ full_path = os.path.join(local_path, model_filename)
316
+
317
+ if not os.path.exists(full_path):
318
+ print(f" Downloading {model_filename} from Hugging Face Hub...")
319
+ try:
320
+ downloaded_path = hf_hub_download(
321
+ repo_id="Shree2604/BioStack",
322
+ filename=model_filename,
323
+ local_dir=local_path,
324
+ local_dir_use_symlinks=False
325
  )
326
+ print(f" Downloaded {model_filename}")
327
+ return downloaded_path
328
+ except Exception as e:
329
+ print(f" Failed to download {model_filename}: {e}")
330
+ raise
331
+ else:
332
+ print(f" Using local {model_filename}")
333
+ return full_path
334
+
335
+ print("\n" + "="*60)
336
+ print(" LOADING MODELS FROM HUGGING FACE HUB")
337
+ print("="*60)
338
+
339
+ # Download models from Hugging Face
340
+ SFT_MODEL_PATH = download_model_from_hf("best_model.pt")
341
+ REWARD_MODEL_PATH = download_model_from_hf("reward_model.pt")
342
+ PPO_MODEL_PATH = download_model_from_hf("rlhf_model.pt")
343
+
344
+ # SFT
345
+ _sft_enc = CoAtNetEncoder(pretrained=False)
346
+ sft_model = load_model(SFT_MODEL_PATH, SFTVisionT5Model(_sft_enc), "SFT Model")
347
+ sft_model.eval()
348
+
349
+ # Reward
350
+ _rm_enc = CoAtNetEncoder(pretrained=False)
351
+ reward_model = load_model(REWARD_MODEL_PATH, RewardModel(_rm_enc), "Reward Model")
352
+ reward_model.eval()
353
+
354
+ # PPO
355
+ _ppo_enc = CoAtNetEncoder(pretrained=False)
356
+ ppo_model = load_model(PPO_MODEL_PATH, PPOVisionT5Model(_ppo_enc), "PPO Model")
357
+ ppo_model.eval()
358
+
359
+ print("\n All models loaded and ready!\n" + "="*60 + "\n")
360
 
 
 
361
 
362
+ # ────────────────────────────────────────────────────────────────��────────────
363
+ # IMAGE PREPROCESSING
364
+ # Matches BOTH notebooks: RGB, 224Γ—224, ImageNet normalisation
365
+ # ─────────────────────────────────────────────────────────────────────────────
366
+ transform = transforms.Compose([
367
+ transforms.Resize((224, 224)),
368
+ transforms.ToTensor(),
369
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
370
+ std=[0.229, 0.224, 0.225])
371
+ ])
372
 
373
+ def preprocess(file_bytes: bytes) -> torch.Tensor:
374
+ img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
375
+ return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
376
 
377
 
378
  # ─────────────────────────────────────────────────────────────────────────────
379
+ # REWARD FEEDBACK GENERATOR
380
  # ─────────────────────────────────────────────────────────────────────────────
381
+ KEY_MEDICAL_TERMS = [
382
+ 'lung', 'heart', 'normal', 'clear', 'opacity', 'infiltrate',
383
+ 'cardiomegaly', 'pleural', 'pulmonary', 'chest', 'thorax',
384
+ 'pneumonia', 'edema', 'effusion', 'consolidation'
385
+ ]
386
+
387
+ def reward_feedback(report: str, score: float) -> str:
388
+ rl = report.lower()
389
+ present = [t for t in KEY_MEDICAL_TERMS if t in rl]
390
+ missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
391
+ words = len(report.split())
392
+ length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
393
+
394
+ # Quality factor assessments based on the score and analysis
395
+ terminology_score = len(present) / len(KEY_MEDICAL_TERMS)
396
+ completeness_score = min(1.0, words / 100.0) # Rough estimate based on length
397
+ structure_score = 1.0 if 50 <= words <= 150 else 0.5 # Good structure if proper length
398
+ radiological_score = score # The overall score represents alignment
399
+
400
+ return (
401
+ f"Reward Score: {score:.2f} | "
402
+ f"Quality Factors - "
403
+ f"Medical Terminology: {terminology_score:.1%} | "
404
+ f"Clinical Completeness: {completeness_score:.1%} | "
405
+ f"Report Structure: {structure_score:.1%}"
406
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
 
408
 
409
  # ─────────────────────────────────────────────────────────────────────────────
410
  # FASTAPI APP
411
  # ─────────────────────────────────────────────────────────────────────────────
412
+ app = FastAPI(title="RLHF Medical Demo")
413
 
414
  app.add_middleware(
415
  CORSMiddleware,
416
+ allow_origins=["*"], # Allow all origins for Hugging Face Spaces
417
  allow_methods=["*"],
418
  allow_headers=["*"],
419
  )
420
 
421
 
 
 
 
 
 
 
422
  @app.get("/health")
423
  def health():
424
+ return {"status": "ok", "device": str(device)}
 
 
 
 
 
425
 
426
 
427
  @app.post("/sft")
428
  async def sft_inference(file: UploadFile = File(...)):
 
 
 
429
  try:
430
+ tensor = preprocess(await file.read())
431
+ report = sft_model.generate_reports(tensor)[0]
 
 
 
 
 
 
 
 
 
 
 
 
432
  print(f"[SFT] Generated: {report}")
433
+ return {"report": report[:81]}
 
 
 
434
  except Exception as e:
435
  traceback.print_exc()
436
+ return {"report": f"ERROR: {str(e)}"}
437
 
438
 
439
+ @app.post("/reward")
440
+ async def reward_inference(file: UploadFile = File(...)):
 
 
 
441
  try:
442
+ tensor = preprocess(await file.read())
443
+
444
+ # First get the SFT report to score
445
+ sft_report = sft_model.generate_reports(tensor)[0]
446
+ print(f"[REWARD] Scoring SFT report: {sft_report}")
447
+
448
+ if not sft_report.strip():
449
+ return {"score": 0.0, "feedback": "", "sft_report": ""}
450
+
451
+ enc = tokenizer(
452
+ [sft_report],
453
+ max_length=128,
454
+ padding="max_length",
455
+ truncation=True,
456
+ return_tensors="pt"
457
+ )
458
+ input_ids = enc.input_ids.to(device)
459
+ attention_mask = enc.attention_mask.to(device)
460
+
461
  with torch.no_grad():
462
+ raw_score = reward_model(tensor, input_ids, attention_mask).item()
463
+
464
+ # Detailed debug logging
465
+ print(f"[REWARD] Raw neural network output: {raw_score:.6f}")
466
+ print(f"[REWARD] Clamping to [0,1] range: max(0.0, min(1.0, {raw_score:.6f})) = {max(0.0, min(1.0, raw_score)):.6f}")
467
+
468
+ # Quality assessment details
469
+ rl = sft_report.lower()
470
+ present = [t for t in KEY_MEDICAL_TERMS if t in rl]
471
+ missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
472
+ words = len(sft_report.split())
473
+ length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
474
+
475
+ print(f"[REWARD] Report analysis:")
476
+ print(f" - Total words: {words} ({length_q})")
477
+ print(f" - Medical terms present ({len(present)}/{len(KEY_MEDICAL_TERMS)}): {present}")
478
+ print(f" - Medical terms missing: {missing}")
479
+ print(f" - Key terms list: {KEY_MEDICAL_TERMS}")
480
+
481
+ # Reward model architecture details
482
+ print(f"[REWARD] Model architecture:")
483
+ print(f" - CoAtNet feature dim: {reward_model.img_encoder.feature_dim}")
484
+ print(f" - T5 d_model: {reward_model.txt_encoder.config.d_model}")
485
+ print(f" - Combined feature dim: 1024 (512 img + 512 text)")
486
+ print(f" - Reward head: 1024β†’512β†’256β†’1")
487
+
488
+ # Clamped score for display
489
+ score = float(max(0.0, min(1.0, raw_score)))
490
+ feedback = reward_feedback(sft_report, score)
491
+ print(f"[REWARD] Final Score={score:.3f}")
492
+ return {"score": score, "feedback": feedback, "sft_report": sft_report}
493
+
494
  except Exception as e:
495
  traceback.print_exc()
496
+ return {"score": 0.0, "feedback": f"ERROR: {str(e)}", "sft_report": ""}
497
 
498
 
499
+ @app.post("/ppo")
500
+ async def ppo_inference(file: UploadFile = File(...)):
 
 
 
501
  try:
502
+ tensor = preprocess(await file.read())
503
+ report = ppo_model.generate_reports(tensor)[0]
504
+ print(f"[PPO] Generated: {report}")
505
+ return {"report": report}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  except Exception as e:
507
  traceback.print_exc()
508
+ return {"report": f"ERROR: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
 
511
  # ─────────────────────────────────────────────────────────────────────────────
512
+ # DIAGNOSTIC ENDPOINT β€” call GET /debug_keys to verify key names in your files
513
+ # e.g. curl http://localhost:8000/debug_keys
514
+ # ─────────────────────────────────────────────────────────────────────────────
515
+ @app.get("/debug_keys")
516
+ def debug_keys():
517
+ import os
518
+ result = {}
519
+ for label, path in [("SFT", SFT_MODEL_PATH), ("Reward", REWARD_MODEL_PATH), ("PPO", PPO_MODEL_PATH)]:
520
+ if not os.path.exists(path):
521
+ result[label] = f"FILE NOT FOUND: {path}"
522
+ continue
523
+ try:
524
+ sd = torch.load(path, map_location="cpu")
525
+ keys = list(sd.keys())
526
+ result[label] = {"first_10_keys": keys[:10], "total_keys": len(keys)}
527
+ except Exception as e:
528
+ result[label] = f"ERROR: {e}"
529
+ return result
530
+
531
+
532
+ # ─────────────────────────────────────────────────────────────────────────────
533
+ # STATIC FILE SERVING - Mount React build directory AFTER all API routes
534
  # ─────────────────────────────────────────────────────────────────────────────
535
  from fastapi.staticfiles import StaticFiles
536
+ import os
537
 
538
+ # Check if build directory exists, create fallback if needed
539
  if os.path.exists("build"):
540
  app.mount("/", StaticFiles(directory="build", html=True), name="static")
541
  print("βœ… React app mounted at /")