Shree2604 committed on
Commit
5a2d89a
·
verified ·
1 Parent(s): f3e4ffb

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +172 -105
server.py CHANGED
@@ -12,49 +12,53 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ─────────────────────────────────────────────────────────────────────────────
15
- # CONFIGURATION - Matching Colab Notebook Exactly
16
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  CONFIG = {
 
18
  'coatnet_model': 'coatnet_1_rw_224',
19
  't5_model': 't5-small',
20
  'img_emb_dim': 768,
21
  'train_last_stages': 2,
 
 
22
  'image_size': 224,
 
 
23
  'max_length': 100,
24
  'num_beams': 4,
 
 
 
25
  }
26
 
27
- # ─────────────────────────────────────────────────────────────────────────────
28
- # DEVICE
29
- # ─────────────────────────────────────────────────────────────────────────────
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- print(f"🖥️ Using device: {device}")
32
 
33
  # ─────────────────────────────────────────────────────────────────────────────
34
- # LOAD TOKENIZER - Matching Colab
35
  # ─────────────────────────────────────────────────────────────────────────────
36
  print("\n" + "="*80)
37
- print("LOADING TOKENIZER")
38
  print("="*80)
39
- tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
40
- print(f"✓ Loaded tokenizer: {CONFIG['t5_model']}")
41
-
42
- # ─────────────────────────────────────────────────────────────────────────────
43
- # IMAGE TRANSFORM - Matching Colab Exactly
44
- # ─────────────────────────────────────────────────────────────────────────────
45
- transform = transforms.Compose([
46
- transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
- transforms.ToTensor(),
48
- transforms.Normalize(
49
- mean=[0.485, 0.456, 0.406],
50
- std=[0.229, 0.224, 0.225]
51
- )
52
- ])
53
- print(f"✓ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
 
55
- # ─────────────────────────────────────────────────────────────────────────────
56
- # ARCHITECTURE 1: CoAtNetEncoder - Exactly from Colab SECTION 6
57
- # ─────────────────────────────────────────────────────────────────────────────
58
  class CoAtNetEncoder(nn.Module):
59
  def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
60
  super().__init__()
@@ -80,9 +84,7 @@ class CoAtNetEncoder(nn.Module):
80
  return self.encoder(x)
81
 
82
 
83
- # ─────────────────────────────────────────────────────────────────────────────
84
- # ARCHITECTURE 2: VisionT5Model - Exactly from Colab SECTION 6
85
- # ─────────────────────────────────────────────────────────────────────────────
86
  class VisionT5Model(nn.Module):
87
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
88
  super().__init__()
@@ -127,9 +129,6 @@ class VisionT5Model(nn.Module):
127
  return outputs
128
 
129
  def generate_reports(self, pixel_values, max_length=100, num_beams=4):
130
- """
131
- Generate reports - EXACTLY matching Colab SECTION 6
132
- """
133
  # Extract and project image features
134
  img_feats = self.img_encoder(pixel_values)
135
  img_feats = self.proj(img_feats)
@@ -140,7 +139,7 @@ class VisionT5Model(nn.Module):
140
  inputs_embeds=encoder_hidden_states
141
  )
142
 
143
- # Generate report using beam search - EXACT parameters from Colab
144
  generated_ids = self.t5.generate(
145
  encoder_outputs=encoder_outputs,
146
  attention_mask=torch.ones(
@@ -157,11 +156,42 @@ class VisionT5Model(nn.Module):
157
  print("✓ Model architecture classes defined")
158
 
159
  # ─────────────────────────────────────────────────────────────────────────────
160
- # MODEL LOADING FUNCTION - Exactly from Colab SECTION 8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  # ─────────────────────────────────────────────────────────────────────────────
162
  def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
163
  """
164
- Load VisionT5Model from checkpoint - EXACT implementation from Colab
 
 
 
 
 
 
 
 
 
165
  """
166
  print(f"\nLoading {model_name} model...")
167
  print(f" Checkpoint: {checkpoint_path}")
@@ -242,22 +272,36 @@ def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: di
242
 
243
 
244
  # ─────────────────────────────────────────────────────────────────────────────
245
- # INFERENCE FUNCTION - Exactly from Colab SECTION 9
246
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
247
  def generate_report(
248
  image_path: str,
249
  model: VisionT5Model,
250
  config: dict
251
  ) -> str:
252
  """
253
- Generate medical report from X-ray image - EXACT implementation from Colab
 
 
 
 
 
 
 
 
 
254
  """
255
  try:
256
  # Preprocess image
257
- image = Image.open(image_path).convert('RGB')
258
- pixel_values = transform(image).unsqueeze(0).to(device)
259
 
260
- # Generate report - using EXACT parameters from Colab
261
  with torch.no_grad():
262
  generated_ids = model.generate_reports(
263
  pixel_values,
@@ -301,7 +345,7 @@ except Exception as e:
301
  PPO_MODEL_PATH = "/content/rlhf_model.pt"
302
  print(f"⚠️ Using local paths instead")
303
 
304
- # Load both models
305
  print("\n" + "="*80)
306
  print("LOADING MODELS")
307
  print("="*80)
@@ -323,7 +367,7 @@ print("\n✓ Both models loaded successfully!")
323
  # ─────────────────────────────────────────────────────────────────────────────
324
  # FASTAPI APP
325
  # ─────────────────────────────────────────────────────────────────────────────
326
- app = FastAPI(title="Medical Report Generation - Matching Colab")
327
 
328
  app.add_middleware(
329
  CORSMiddleware,
@@ -333,46 +377,41 @@ app.add_middleware(
333
  )
334
 
335
 
336
- def preprocess_bytes(file_bytes: bytes) -> torch.Tensor:
337
- """Preprocess image bytes for inference"""
338
- img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
339
- return transform(img).unsqueeze(0).to(device)
340
-
341
-
342
  @app.get("/health")
343
  def health():
344
  return {
345
  "status": "ok",
346
  "device": str(device),
 
347
  "models_loaded": True,
348
- "config": CONFIG
349
  }
350
 
351
 
352
  @app.post("/sft")
353
  async def sft_inference(file: UploadFile = File(...)):
354
  """
355
- SFT model inference - EXACTLY matching Colab behavior
356
  """
357
  try:
358
- # Preprocess image
359
- tensor = preprocess_bytes(await file.read())
 
 
360
 
361
- # Generate report using EXACT Colab parameters
362
- with torch.no_grad():
363
- generated_ids = sft_model.generate_reports(
364
- tensor,
365
- max_length=CONFIG['max_length'],
366
- num_beams=CONFIG['num_beams']
367
- )
368
 
369
- # Decode - EXACTLY as Colab does
370
- report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
371
 
372
- print(f"[SFT] Generated: {report}")
373
 
374
- # Return FULL report without truncation
375
- return {"report": report, "model": "SFT", "config_used": CONFIG}
 
 
 
376
 
377
  except Exception as e:
378
  traceback.print_exc()
@@ -382,27 +421,27 @@ async def sft_inference(file: UploadFile = File(...)):
382
  @app.post("/ppo")
383
  async def ppo_inference(file: UploadFile = File(...)):
384
  """
385
- PPO model inference - EXACTLY matching Colab behavior
386
  """
387
  try:
388
- # Preprocess image
389
- tensor = preprocess_bytes(await file.read())
 
 
390
 
391
- # Generate report using EXACT Colab parameters
392
- with torch.no_grad():
393
- generated_ids = ppo_model.generate_reports(
394
- tensor,
395
- max_length=CONFIG['max_length'],
396
- num_beams=CONFIG['num_beams']
397
- )
398
 
399
- # Decode - EXACTLY as Colab does
400
- report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
401
 
402
- print(f"[PPO] Generated: {report}")
403
 
404
- # Return FULL report without truncation
405
- return {"report": report, "model": "PPO", "config_used": CONFIG}
 
 
 
406
 
407
  except Exception as e:
408
  traceback.print_exc()
@@ -413,28 +452,20 @@ async def ppo_inference(file: UploadFile = File(...)):
413
  async def compare_models(file: UploadFile = File(...)):
414
  """
415
  Generate reports from both models for comparison
 
416
  """
417
  try:
418
- file_bytes = await file.read()
419
- tensor = preprocess_bytes(file_bytes)
 
 
420
 
421
- # SFT Generation
422
- with torch.no_grad():
423
- sft_ids = sft_model.generate_reports(
424
- tensor,
425
- max_length=CONFIG['max_length'],
426
- num_beams=CONFIG['num_beams']
427
- )
428
- sft_report = tokenizer.decode(sft_ids[0], skip_special_tokens=True).strip()
429
 
430
- # PPO Generation
431
- with torch.no_grad():
432
- ppo_ids = ppo_model.generate_reports(
433
- tensor,
434
- max_length=CONFIG['max_length'],
435
- num_beams=CONFIG['num_beams']
436
- )
437
- ppo_report = tokenizer.decode(ppo_ids[0], skip_special_tokens=True).strip()
438
 
439
  print(f"[COMPARE] SFT: {sft_report}")
440
  print(f"[COMPARE] PPO: {ppo_report}")
@@ -442,7 +473,8 @@ async def compare_models(file: UploadFile = File(...)):
442
  return {
443
  "sft_report": sft_report,
444
  "ppo_report": ppo_report,
445
- "config_used": CONFIG
 
446
  }
447
 
448
  except Exception as e:
@@ -453,19 +485,43 @@ async def compare_models(file: UploadFile = File(...)):
453
  }
454
 
455
 
456
- @app.get("/debug_config")
457
- def debug_config():
458
- """Debug endpoint to check configuration"""
 
 
459
  return {
460
- "config": CONFIG,
461
  "device": str(device),
 
 
 
 
 
 
 
 
 
 
462
  "tokenizer": CONFIG['t5_model'],
463
- "image_size": CONFIG['image_size'],
464
- "max_length": CONFIG['max_length'],
465
- "num_beams": CONFIG['num_beams'],
 
 
 
 
 
 
 
 
 
466
  "models_loaded": {
467
  "sft": sft_model is not None,
468
  "ppo": ppo_model is not None
 
 
 
 
469
  }
470
  }
471
 
@@ -481,6 +537,17 @@ if os.path.exists("build"):
481
  else:
482
  print("⚠️ Build directory not found, serving API only")
483
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
  if __name__ == "__main__":
486
  import uvicorn
 
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ─────────────────────────────────────────────────────────────────────────────
15
+ # CONFIGURATION - EXACTLY matching Colab CONFIG from SECTION 4
16
  # ─────────────────────────────────────────────────────────────────────────────
17
# Startup banner for the configuration phase.
print("=" * 80, "INITIALIZING CONFIGURATION", "=" * 80, sep="\n")

# Pick the GPU when PyTorch reports CUDA support, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    # Release any cached CUDA allocations before the models are loaded.
    torch.cuda.empty_cache()
print(f"🖥️ Using device: {device}")

# Single source of settings shared by model construction, preprocessing and
# generation.  Key order is kept as-is because endpoints echo this dict.
CONFIG = dict(
    # Model architecture settings
    coatnet_model='coatnet_1_rw_224',
    t5_model='t5-small',
    img_emb_dim=768,
    train_last_stages=2,
    # Image preprocessing
    image_size=224,
    # Inference settings
    max_length=100,
    num_beams=4,
    # Device
    device=device,
)

# Echo every setting except the device, which was already printed above.
print("\nConfiguration loaded:")
for key, value in CONFIG.items():
    if key == 'device':
        continue
    print(f" {key}: {value}")
 
53
 
54
  # ─────────────────────────────────────────────────────────────────────────────
55
+ # SECTION 6: Model Architecture Definitions - EXACT COPY from Colab
56
  # ─────────────────────────────────────────────────────────────────────────────
57
  print("\n" + "="*80)
58
+ print("DEFINING MODEL ARCHITECTURES")
59
  print("="*80)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # --- Encoder: CoAtNet --- EXACT COPY from Colab SECTION 6
 
 
62
  class CoAtNetEncoder(nn.Module):
63
  def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
64
  super().__init__()
 
84
  return self.encoder(x)
85
 
86
 
87
+ # --- Vision-T5 Model --- EXACT COPY from Colab SECTION 6
 
 
88
  class VisionT5Model(nn.Module):
89
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
90
  super().__init__()
 
129
  return outputs
130
 
131
  def generate_reports(self, pixel_values, max_length=100, num_beams=4):
 
 
 
132
  # Extract and project image features
133
  img_feats = self.img_encoder(pixel_values)
134
  img_feats = self.proj(img_feats)
 
139
  inputs_embeds=encoder_hidden_states
140
  )
141
 
142
+ # Generate report using beam search
143
  generated_ids = self.t5.generate(
144
  encoder_outputs=encoder_outputs,
145
  attention_mask=torch.ones(
 
156
  print("✓ Model architecture classes defined")
157
 
158
  # ─────────────────────────────────────────────────────────────────────────────
159
+ # SECTION 7: Load Tokenizer and Image Transform - EXACT COPY from Colab
160
+ # ─────────────────────────────────────────────────────────────────────────────
161
print("\n" + "=" * 80, "LOADING TOKENIZER AND IMAGE TRANSFORM", "=" * 80, sep="\n")

# Tokenizer for the T5 checkpoint named in CONFIG; used to decode generated ids.
tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
print(f"✓ Loaded tokenizer: {CONFIG['t5_model']}")

# Preprocessing pipeline: resize to a square of CONFIG['image_size'], convert
# to a tensor, then normalize each channel with the fixed mean/std below.
transform = transforms.Compose([
    transforms.Resize((CONFIG['image_size'],) * 2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
print(f"✓ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
179
+
180
+ # ─────────────────────────────────────────────────────────────────────────────
181
+ # SECTION 8: Model Loading Functions - EXACT COPY from Colab
182
  # ─────────────────────────────────────────────────────────────────────────────
183
  def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
184
  """
185
+ Load VisionT5Model from checkpoint.
186
+ EXACT COPY from Colab SECTION 8
187
+
188
+ Args:
189
+ checkpoint_path: Path to .pt checkpoint file
190
+ model_name: Name for logging (e.g., 'SFT' or 'PPO')
191
+ config: Configuration dictionary
192
+
193
+ Returns:
194
+ Loaded model
195
  """
196
  print(f"\nLoading {model_name} model...")
197
  print(f" Checkpoint: {checkpoint_path}")
 
272
 
273
 
274
  # ─────────────────────────────────────────────────────────────────────────────
275
+ # SECTION 9: Inference Functions - EXACT COPY from Colab
276
  # ─────────────────────────────────────────────────────────────────────────────
277
def preprocess_image(image_path: str) -> torch.Tensor:
    """Load an image from disk and run it through the module-level ``transform``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the RGB copy exists, instead of whenever the image object is eventually
    garbage-collected.  This matters when the caller deletes the file right
    after inference.

    Args:
        image_path: Path to the image file to load.

    Returns:
        The output of ``transform`` for the RGB image.  No batch dimension is
        added here; the caller is expected to add one if it needs it.
    """
    with Image.open(image_path) as source:
        # convert() reads the pixel data and returns a new image, so the
        # result stays usable after the source file is closed.
        rgb_image = source.convert('RGB')
    return transform(rgb_image)
281
+
282
+
283
  def generate_report(
284
  image_path: str,
285
  model: VisionT5Model,
286
  config: dict
287
  ) -> str:
288
  """
289
+ Generate medical report from X-ray image.
290
+ EXACT COPY from Colab SECTION 9
291
+
292
+ Args:
293
+ image_path: Path to X-ray image
294
+ model: VisionT5Model
295
+ config: Configuration dictionary
296
+
297
+ Returns:
298
+ Generated report text
299
  """
300
  try:
301
  # Preprocess image
302
+ pixel_values = preprocess_image(image_path).unsqueeze(0).to(device)
 
303
 
304
+ # Generate report
305
  with torch.no_grad():
306
  generated_ids = model.generate_reports(
307
  pixel_values,
 
345
  PPO_MODEL_PATH = "/content/rlhf_model.pt"
346
  print(f"⚠️ Using local paths instead")
347
 
348
+ # Load both models - EXACTLY as Colab SECTION 8
349
  print("\n" + "="*80)
350
  print("LOADING MODELS")
351
  print("="*80)
 
367
  # ─────────────────────────────────────────────────────────────────────────────
368
  # FASTAPI APP
369
  # ─────────────────────────────────────────────────────────────────────────────
370
+ app = FastAPI(title="Medical Report Generation - Exact Colab Match")
371
 
372
  app.add_middleware(
373
  CORSMiddleware,
 
377
  )
378
 
379
 
 
 
 
 
 
 
380
@app.get("/health")
def health():
    """Report service status together with device and configuration details.

    The ``device`` entry of CONFIG is left out of the ``config`` payload; the
    device is reported separately as a string.
    """
    public_config = {}
    for name, setting in CONFIG.items():
        if name != 'device':
            public_config[name] = setting

    payload = {"status": "ok"}
    payload["device"] = str(device)
    payload["cuda_available"] = torch.cuda.is_available()
    payload["models_loaded"] = True
    payload["config"] = public_config
    return payload
389
 
390
 
391
  @app.post("/sft")
392
  async def sft_inference(file: UploadFile = File(...)):
393
  """
394
+ SFT model inference - Uses EXACT generate_report() function from Colab SECTION 9
395
  """
396
  try:
397
+ # Save uploaded file temporarily
398
+ temp_path = f"/tmp/{file.filename}"
399
+ with open(temp_path, "wb") as f:
400
+ f.write(await file.read())
401
 
402
+ # Use EXACT generate_report function from Colab
403
+ report = generate_report(temp_path, sft_model, CONFIG)
 
 
 
 
 
404
 
405
+ # Clean up temp file
406
+ os.remove(temp_path)
407
 
408
+ print(f"[SFT] Generated report: {report}")
409
 
410
+ return {
411
+ "report": report,
412
+ "model": "SFT",
413
+ "method": "generate_report() - exact Colab SECTION 9"
414
+ }
415
 
416
  except Exception as e:
417
  traceback.print_exc()
 
421
  @app.post("/ppo")
422
  async def ppo_inference(file: UploadFile = File(...)):
423
  """
424
+ PPO model inference - Uses EXACT generate_report() function from Colab SECTION 9
425
  """
426
  try:
427
+ # Save uploaded file temporarily
428
+ temp_path = f"/tmp/{file.filename}"
429
+ with open(temp_path, "wb") as f:
430
+ f.write(await file.read())
431
 
432
+ # Use EXACT generate_report function from Colab
433
+ report = generate_report(temp_path, ppo_model, CONFIG)
 
 
 
 
 
434
 
435
+ # Clean up temp file
436
+ os.remove(temp_path)
437
 
438
+ print(f"[PPO] Generated report: {report}")
439
 
440
+ return {
441
+ "report": report,
442
+ "model": "PPO",
443
+ "method": "generate_report() - exact Colab SECTION 9"
444
+ }
445
 
446
  except Exception as e:
447
  traceback.print_exc()
 
452
  async def compare_models(file: UploadFile = File(...)):
453
  """
454
  Generate reports from both models for comparison
455
+ Uses EXACT generate_report() function from Colab
456
  """
457
  try:
458
+ # Save uploaded file temporarily
459
+ temp_path = f"/tmp/{file.filename}"
460
+ with open(temp_path, "wb") as f:
461
+ f.write(await file.read())
462
 
463
+ # Use EXACT generate_report function from Colab for both models
464
+ sft_report = generate_report(temp_path, sft_model, CONFIG)
465
+ ppo_report = generate_report(temp_path, ppo_model, CONFIG)
 
 
 
 
 
466
 
467
+ # Clean up temp file
468
+ os.remove(temp_path)
 
 
 
 
 
 
469
 
470
  print(f"[COMPARE] SFT: {sft_report}")
471
  print(f"[COMPARE] PPO: {ppo_report}")
 
473
  return {
474
  "sft_report": sft_report,
475
  "ppo_report": ppo_report,
476
+ "method": "generate_report() - exact Colab SECTION 9",
477
+ "config": {k: v for k, v in CONFIG.items() if k != 'device'}
478
  }
479
 
480
  except Exception as e:
 
485
  }
486
 
487
 
488
@app.get("/debug_inference")
def debug_inference():
    """Return a snapshot of the inference setup for comparison with the Colab notebook.

    The payload echoes selected CONFIG values, the device, whether each model
    object is present, and whether each model is currently in eval mode.
    """
    side = CONFIG['image_size']
    beams = CONFIG['num_beams']
    length_cap = CONFIG['max_length']

    echoed_keys = (
        'coatnet_model',
        't5_model',
        'img_emb_dim',
        'train_last_stages',
        'image_size',
        'max_length',
        'num_beams',
    )
    config_view = {name: CONFIG[name] for name in echoed_keys}

    def eval_flag(model):
        # None when the model is falsy; otherwise True iff it is not training.
        if not model:
            return None
        return not model.training

    # NOTE: the mean/std values are literal copies written here; they are not
    # read back from the transform object.
    transform_view = {
        "resize": f"{side}x{side}",
        "normalize_mean": [0.485, 0.456, 0.406],
        "normalize_std": [0.229, 0.224, 0.225],
    }
    generation_view = {
        "max_length": length_cap,
        "num_beams": beams,
        "early_stopping": True,
        "no_extra_penalties": "✓ Exactly as Colab",
    }
    loaded_view = {
        "sft": sft_model is not None,
        "ppo": ppo_model is not None,
    }
    state_view = {
        "sft_eval_mode": eval_flag(sft_model),
        "ppo_eval_mode": eval_flag(ppo_model),
    }

    return {
        "device": str(device),
        "cuda_available": torch.cuda.is_available(),
        "config": config_view,
        "tokenizer": CONFIG['t5_model'],
        "transform": transform_view,
        "generation_params": generation_view,
        "inference_method": "generate_report() from Colab SECTION 9",
        "models_loaded": loaded_view,
        "model_state": state_view,
    }
527
 
 
537
  else:
538
  print("⚠️ Build directory not found, serving API only")
539
 
540
# One-shot startup summary, emitted as a single print with newline separators
# so the output is the same sequence of lines as individual print calls.
print(
    "\n" + "=" * 80,
    "SERVER READY - Using EXACT Colab Inference Code",
    "=" * 80,
    "Key points:",
    " ✓ Model architecture: VisionT5Model (exact copy from Colab SECTION 6)",
    " ✓ Inference method: generate_report() (exact copy from Colab SECTION 9)",
    " ✓ Generation params: max_length=100, num_beams=4, early_stopping=True",
    " ✓ No extra penalties: NO repetition_penalty, NO no_repeat_ngram_size",
    " ✓ Transform: Resize 224x224, Normalize [0.485,0.456,0.406]/[0.229,0.224,0.225]",
    " ✓ Device handling: Same as Colab",
    "=" * 80,
    sep="\n",
)
551
 
552
  if __name__ == "__main__":
553
  import uvicorn