Shree2604 commited on
Commit
f0f2c4a
·
verified ·
1 Parent(s): d8c1e26

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +125 -269
server.py CHANGED
@@ -12,53 +12,49 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ─────────────────────────────────────────────────────────────────────────────
15
- # CONFIGURATION - EXACTLY matching Colab CONFIG from SECTION 4
16
  # ─────────────────────────────────────────────────────────────────────────────
17
- print("="*80)
18
- print("INITIALIZING CONFIGURATION")
19
- print("="*80)
20
-
21
- # Device setup - EXACTLY as Colab SECTION 3
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- print(f"PyTorch version: {torch.__version__}")
24
- print(f"CUDA available: {torch.cuda.is_available()}")
25
- if torch.cuda.is_available():
26
- print(f"GPU Device: {torch.cuda.get_device_name(0)}")
27
- torch.cuda.empty_cache()
28
- print(f"🖥️ Using device: {device}")
29
-
30
- # Configuration - EXACTLY matching Colab SECTION 4
31
  CONFIG = {
32
- # Model architecture settings
33
  'coatnet_model': 'coatnet_1_rw_224',
34
  't5_model': 't5-small',
35
  'img_emb_dim': 768,
36
  'train_last_stages': 2,
37
-
38
- # Image preprocessing
39
  'image_size': 224,
40
-
41
- # Inference settings
42
  'max_length': 100,
43
  'num_beams': 4,
44
-
45
- # Device
46
- 'device': device
47
  }
48
 
49
- print("\nConfiguration loaded:")
50
- for key, value in CONFIG.items():
51
- if key != 'device':
52
- print(f" {key}: {value}")
 
53
 
54
  # ─────────────────────────────────────────────────────────────────────────────
55
- # SECTION 6: Model Architecture Definitions - EXACT COPY from Colab
56
  # ─────────────────────────────────────────────────────────────────────────────
57
  print("\n" + "="*80)
58
- print("DEFINING MODEL ARCHITECTURES")
59
  print("="*80)
 
 
60
 
61
- # --- Encoder: CoAtNet --- EXACT COPY from Colab SECTION 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  class CoAtNetEncoder(nn.Module):
63
  def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
64
  super().__init__()
@@ -84,7 +80,9 @@ class CoAtNetEncoder(nn.Module):
84
  return self.encoder(x)
85
 
86
 
87
- # --- Vision-T5 Model --- EXACT COPY from Colab SECTION 6
 
 
88
  class VisionT5Model(nn.Module):
89
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
90
  super().__init__()
@@ -129,6 +127,9 @@ class VisionT5Model(nn.Module):
129
  return outputs
130
 
131
  def generate_reports(self, pixel_values, max_length=100, num_beams=4):
 
 
 
132
  # Extract and project image features
133
  img_feats = self.img_encoder(pixel_values)
134
  img_feats = self.proj(img_feats)
@@ -139,7 +140,7 @@ class VisionT5Model(nn.Module):
139
  inputs_embeds=encoder_hidden_states
140
  )
141
 
142
- # Generate report using beam search
143
  generated_ids = self.t5.generate(
144
  encoder_outputs=encoder_outputs,
145
  attention_mask=torch.ones(
@@ -156,42 +157,11 @@ class VisionT5Model(nn.Module):
156
  print("✓ Model architecture classes defined")
157
 
158
  # ─────────────────────────────────────────────────────────────────────────────
159
- # SECTION 7: Load Tokenizer and Image Transform - EXACT COPY from Colab
160
- # ─────────────────────────────────────────────────────────────────────────────
161
- print("\n" + "="*80)
162
- print("LOADING TOKENIZER AND IMAGE TRANSFORM")
163
- print("="*80)
164
-
165
- # Load tokenizer
166
- tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
167
- print(f"✓ Loaded tokenizer: {CONFIG['t5_model']}")
168
-
169
- # Define image transform - EXACTLY as Colab SECTION 7
170
- transform = transforms.Compose([
171
- transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
172
- transforms.ToTensor(),
173
- transforms.Normalize(
174
- mean=[0.485, 0.456, 0.406],
175
- std=[0.229, 0.224, 0.225]
176
- )
177
- ])
178
- print(f"✓ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
179
-
180
- # ─────────────────────────────────────────────────────────────────────────────
181
- # SECTION 8: Model Loading Functions - EXACT COPY from Colab
182
  # ─────────────────────────────────────────────────────────────────────────────
183
  def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
184
  """
185
- Load VisionT5Model from checkpoint.
186
- EXACT COPY from Colab SECTION 8
187
-
188
- Args:
189
- checkpoint_path: Path to .pt checkpoint file
190
- model_name: Name for logging (e.g., 'SFT' or 'PPO')
191
- config: Configuration dictionary
192
-
193
- Returns:
194
- Loaded model
195
  """
196
  print(f"\nLoading {model_name} model...")
197
  print(f" Checkpoint: {checkpoint_path}")
@@ -272,36 +242,22 @@ def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: di
272
 
273
 
274
  # ─────────────────────────────────────────────────────────────────────────────
275
- # SECTION 9: Inference Functions - EXACT COPY from Colab
276
  # ─────────────────────────────────────────────────────────────────────────────
277
- def preprocess_image(image_path: str) -> torch.Tensor:
278
- """Load and preprocess image. EXACT COPY from Colab SECTION 9"""
279
- image = Image.open(image_path).convert('RGB')
280
- return transform(image)
281
-
282
-
283
  def generate_report(
284
  image_path: str,
285
  model: VisionT5Model,
286
  config: dict
287
  ) -> str:
288
  """
289
- Generate medical report from X-ray image.
290
- EXACT COPY from Colab SECTION 9
291
-
292
- Args:
293
- image_path: Path to X-ray image
294
- model: VisionT5Model
295
- config: Configuration dictionary
296
-
297
- Returns:
298
- Generated report text
299
  """
300
  try:
301
  # Preprocess image
302
- pixel_values = preprocess_image(image_path).unsqueeze(0).to(device)
 
303
 
304
- # Generate report
305
  with torch.no_grad():
306
  generated_ids = model.generate_reports(
307
  pixel_values,
@@ -320,57 +276,34 @@ def generate_report(
320
 
321
 
322
  # ─────────────────────────────────────────────────────────────────────────────
323
- # LOAD MODELS FROM HUGGINGFACE - Shree2604/BioStack
324
  # ─────────────────────────────────────────────────────────────────────────────
325
  print("\n" + "="*80)
326
- print("DOWNLOADING MODELS FROM HUGGINGFACE")
327
  print("="*80)
328
 
329
- # Hugging Face repository
330
- HF_REPO = "Shree2604/BioStack"
331
-
332
  # Download model files from Hugging Face
333
  try:
334
- print(f"📦 Downloading from repository: {HF_REPO}")
335
- print("This may take a few minutes on first run...\n")
336
-
337
- # Download SFT model
338
- print("1️⃣ Downloading SFT model (best_model.pt)...")
339
  SFT_MODEL_PATH = hf_hub_download(
340
- repo_id=HF_REPO,
341
  filename="best_model.pt"
342
  )
343
- print(f" ✓ SFT model downloaded: {SFT_MODEL_PATH}")
344
-
345
- # Download Reward model
346
- print("\n2️⃣ Downloading Reward model (reward_model.pt)...")
347
- REWARD_MODEL_PATH = hf_hub_download(
348
- repo_id=HF_REPO,
349
- filename="reward_model.pt"
350
- )
351
- print(f" ✓ Reward model downloaded: {REWARD_MODEL_PATH}")
352
-
353
- # Download PPO model
354
- print("\n3️⃣ Downloading PPO model (rlhf_model.pt)...")
355
  PPO_MODEL_PATH = hf_hub_download(
356
- repo_id=HF_REPO,
357
  filename="rlhf_model.pt"
358
  )
359
- print(f" PPO model downloaded: {PPO_MODEL_PATH}")
360
-
361
- print("\n✅ All models downloaded successfully!")
362
-
363
  except Exception as e:
364
- print(f"\n❌ Error downloading models from Hugging Face: {e}")
365
- print("Please check:")
366
- print(f" - Repository exists: https://huggingface.co/{HF_REPO}")
367
- print(" - Files exist: best_model.pt, reward_model.pt, rlhf_model.pt")
368
- print(" - You have internet connection")
369
- raise
370
-
371
- # Load both models - EXACTLY as Colab SECTION 8
372
  print("\n" + "="*80)
373
- print("LOADING MODELS INTO MEMORY")
374
  print("="*80)
375
 
376
  sft_model = load_model_from_checkpoint(
@@ -385,17 +318,12 @@ ppo_model = load_model_from_checkpoint(
385
  CONFIG
386
  )
387
 
388
- print("\n Both models loaded successfully!")
389
- print("="*80)
390
 
391
  # ─────────────────────────────────────────────────────────────────────────────
392
  # FASTAPI APP
393
  # ─────────────────────────────────────────────────────────────────────────────
394
- app = FastAPI(
395
- title="BioStack Medical Report Generation",
396
- description="Medical X-ray report generation using SFT and PPO models from Shree2604/BioStack",
397
- version="1.0.0"
398
- )
399
 
400
  app.add_middleware(
401
  CORSMiddleware,
@@ -405,23 +333,10 @@ app.add_middleware(
405
  )
406
 
407
 
408
- @app.get("/")
409
- def root():
410
- return {
411
- "message": "BioStack Medical Report Generation API",
412
- "repository": "Shree2604/BioStack",
413
- "models": {
414
- "sft": "best_model.pt",
415
- "ppo": "rlhf_model.pt",
416
- "reward": "reward_model.pt"
417
- },
418
- "endpoints": {
419
- "health": "GET /health - Check API status",
420
- "sft": "POST /sft - Generate report using SFT model",
421
- "ppo": "POST /ppo - Generate report using PPO model",
422
- "compare": "POST /compare - Compare both models"
423
- }
424
- }
425
 
426
 
427
  @app.get("/health")
@@ -429,111 +344,97 @@ def health():
429
  return {
430
  "status": "ok",
431
  "device": str(device),
432
- "cuda_available": torch.cuda.is_available(),
433
- "models_loaded": {
434
- "sft": sft_model is not None,
435
- "ppo": ppo_model is not None
436
- },
437
- "repository": HF_REPO,
438
- "model_files": {
439
- "sft": os.path.basename(SFT_MODEL_PATH),
440
- "ppo": os.path.basename(PPO_MODEL_PATH)
441
- }
442
  }
443
 
444
 
445
  @app.post("/sft")
446
  async def sft_inference(file: UploadFile = File(...)):
447
  """
448
- SFT model inference - Uses EXACT generate_report() function from Colab SECTION 9
449
-
450
- Model: best_model.pt from Shree2604/BioStack
451
  """
452
  try:
453
- # Save uploaded file temporarily
454
- temp_path = f"/tmp/{file.filename}"
455
- with open(temp_path, "wb") as f:
456
- f.write(await file.read())
457
 
458
- # Use EXACT generate_report function from Colab
459
- report = generate_report(temp_path, sft_model, CONFIG)
 
 
 
 
 
460
 
461
- # Clean up temp file
462
- os.remove(temp_path)
463
 
464
- print(f"[SFT] Generated report: {report}")
465
 
466
- return {
467
- "report": report,
468
- "model": "SFT",
469
- "source": "best_model.pt",
470
- "repository": HF_REPO
471
- }
472
 
473
  except Exception as e:
474
  traceback.print_exc()
475
- return {
476
- "report": f"ERROR: {str(e)}",
477
- "model": "SFT"
478
- }
479
 
480
 
481
  @app.post("/ppo")
482
  async def ppo_inference(file: UploadFile = File(...)):
483
  """
484
- PPO model inference - Uses EXACT generate_report() function from Colab SECTION 9
485
-
486
- Model: rlhf_model.pt from Shree2604/BioStack
487
  """
488
  try:
489
- # Save uploaded file temporarily
490
- temp_path = f"/tmp/{file.filename}"
491
- with open(temp_path, "wb") as f:
492
- f.write(await file.read())
493
 
494
- # Use EXACT generate_report function from Colab
495
- report = generate_report(temp_path, ppo_model, CONFIG)
 
 
 
 
 
496
 
497
- # Clean up temp file
498
- os.remove(temp_path)
499
 
500
- print(f"[PPO] Generated report: {report}")
501
 
502
- return {
503
- "report": report,
504
- "model": "PPO",
505
- "source": "rlhf_model.pt",
506
- "repository": HF_REPO
507
- }
508
 
509
  except Exception as e:
510
  traceback.print_exc()
511
- return {
512
- "report": f"ERROR: {str(e)}",
513
- "model": "PPO"
514
- }
515
 
516
 
517
  @app.post("/compare")
518
  async def compare_models(file: UploadFile = File(...)):
519
  """
520
  Generate reports from both models for comparison
521
- Uses EXACT generate_report() function from Colab
522
-
523
- Models: best_model.pt and rlhf_model.pt from Shree2604/BioStack
524
  """
525
  try:
526
- # Save uploaded file temporarily
527
- temp_path = f"/tmp/{file.filename}"
528
- with open(temp_path, "wb") as f:
529
- f.write(await file.read())
530
 
531
- # Use EXACT generate_report function from Colab for both models
532
- sft_report = generate_report(temp_path, sft_model, CONFIG)
533
- ppo_report = generate_report(temp_path, ppo_model, CONFIG)
 
 
 
 
 
534
 
535
- # Clean up temp file
536
- os.remove(temp_path)
 
 
 
 
 
 
537
 
538
  print(f"[COMPARE] SFT: {sft_report}")
539
  print(f"[COMPARE] PPO: {ppo_report}")
@@ -541,11 +442,7 @@ async def compare_models(file: UploadFile = File(...)):
541
  return {
542
  "sft_report": sft_report,
543
  "ppo_report": ppo_report,
544
- "models": {
545
- "sft": "best_model.pt",
546
- "ppo": "rlhf_model.pt"
547
- },
548
- "repository": HF_REPO
549
  }
550
 
551
  except Exception as e:
@@ -556,45 +453,19 @@ async def compare_models(file: UploadFile = File(...)):
556
  }
557
 
558
 
559
- @app.get("/model_info")
560
- def model_info():
561
- """
562
- Get detailed information about loaded models
563
- """
564
  return {
565
- "repository": HF_REPO,
566
- "repository_url": f"https://huggingface.co/{HF_REPO}",
567
- "models": {
568
- "sft": {
569
- "filename": "best_model.pt",
570
- "url": f"https://huggingface.co/{HF_REPO}/blob/main/best_model.pt",
571
- "local_path": SFT_MODEL_PATH,
572
- "loaded": sft_model is not None,
573
- "in_eval_mode": not sft_model.training if sft_model else None
574
- },
575
- "ppo": {
576
- "filename": "rlhf_model.pt",
577
- "url": f"https://huggingface.co/{HF_REPO}/blob/main/rlhf_model.pt",
578
- "local_path": PPO_MODEL_PATH,
579
- "loaded": ppo_model is not None,
580
- "in_eval_mode": not ppo_model.training if ppo_model else None
581
- },
582
- "reward": {
583
- "filename": "reward_model.pt",
584
- "url": f"https://huggingface.co/{HF_REPO}/blob/main/reward_model.pt",
585
- "local_path": REWARD_MODEL_PATH,
586
- "note": "Downloaded but not loaded in this API"
587
- }
588
- },
589
- "architecture": {
590
- "vision_encoder": CONFIG['coatnet_model'],
591
- "text_model": CONFIG['t5_model'],
592
- "image_embedding_dim": CONFIG['img_emb_dim']
593
- },
594
- "inference_config": {
595
- "max_length": CONFIG['max_length'],
596
- "num_beams": CONFIG['num_beams'],
597
- "image_size": CONFIG['image_size']
598
  }
599
  }
600
 
@@ -610,21 +481,6 @@ if os.path.exists("build"):
610
  else:
611
  print("⚠️ Build directory not found, serving API only")
612
 
613
- print("\n" + "="*80)
614
- print("🚀 SERVER READY")
615
- print("="*80)
616
- print(f"Repository: {HF_REPO}")
617
- print("Models loaded:")
618
- print(f" ✓ SFT: best_model.pt")
619
- print(f" ✓ PPO: rlhf_model.pt")
620
- print("\nEndpoints:")
621
- print(" GET / - API info")
622
- print(" GET /health - Health check")
623
- print(" GET /model_info - Model details")
624
- print(" POST /sft - SFT inference")
625
- print(" POST /ppo - PPO inference")
626
- print(" POST /compare - Compare both models")
627
- print("="*80)
628
 
629
  if __name__ == "__main__":
630
  import uvicorn
 
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ─────────────────────────────────────────────────────────────────────────────
15
+ # CONFIGURATION - Matching Colab Notebook Exactly
16
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  CONFIG = {
 
18
  'coatnet_model': 'coatnet_1_rw_224',
19
  't5_model': 't5-small',
20
  'img_emb_dim': 768,
21
  'train_last_stages': 2,
 
 
22
  'image_size': 224,
 
 
23
  'max_length': 100,
24
  'num_beams': 4,
 
 
 
25
  }
26
 
27
+ # ─────────────────────────────────────────────────────────────────────────────
28
+ # DEVICE
29
+ # ─────────────────────────────────────────────────────────────────────────────
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ print(f"🖥️ Using device: {device}")
32
 
33
  # ─────────────────────────────────────────────────────────────────────────────
34
+ # LOAD TOKENIZER - Matching Colab
35
  # ─────────────────────────────────────────────────────────────────────────────
36
  print("\n" + "="*80)
37
+ print("LOADING TOKENIZER")
38
  print("="*80)
39
+ tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
40
+ print(f"✓ Loaded tokenizer: {CONFIG['t5_model']}")
41
 
42
+ # ─────────────────────────────────────────────────────────────────────────────
43
+ # IMAGE TRANSFORM - Matching Colab Exactly
44
+ # ─────────────────────────────────────────────────────────────────────────────
45
+ transform = transforms.Compose([
46
+ transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize(
49
+ mean=[0.485, 0.456, 0.406],
50
+ std=[0.229, 0.224, 0.225]
51
+ )
52
+ ])
53
+ print(f"✓ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
+
55
+ # ─────────────────────────────────────────────────────────────────────────────
56
+ # ARCHITECTURE 1: CoAtNetEncoder - Exactly from Colab SECTION 6
57
+ # ─────────────────────────────────────────────────────────────────────────────
58
  class CoAtNetEncoder(nn.Module):
59
  def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
60
  super().__init__()
 
80
  return self.encoder(x)
81
 
82
 
83
+ # ─────────────────────────────────────────────────────────────────────────────
84
+ # ARCHITECTURE 2: VisionT5Model - Exactly from Colab SECTION 6
85
+ # ─────────────────────────────────────────────────────────────────────────────
86
  class VisionT5Model(nn.Module):
87
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
88
  super().__init__()
 
127
  return outputs
128
 
129
  def generate_reports(self, pixel_values, max_length=100, num_beams=4):
130
+ """
131
+ Generate reports - EXACTLY matching Colab SECTION 6
132
+ """
133
  # Extract and project image features
134
  img_feats = self.img_encoder(pixel_values)
135
  img_feats = self.proj(img_feats)
 
140
  inputs_embeds=encoder_hidden_states
141
  )
142
 
143
+ # Generate report using beam search - EXACT parameters from Colab
144
  generated_ids = self.t5.generate(
145
  encoder_outputs=encoder_outputs,
146
  attention_mask=torch.ones(
 
157
  print("✓ Model architecture classes defined")
158
 
159
  # ─────────────────────────────────────────────────────────────────────────────
160
+ # MODEL LOADING FUNCTION - Exactly from Colab SECTION 8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  # ─────────────────────────────────────────────────────────────────────────────
162
  def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
163
  """
164
+ Load VisionT5Model from checkpoint - EXACT implementation from Colab
 
 
 
 
 
 
 
 
 
165
  """
166
  print(f"\nLoading {model_name} model...")
167
  print(f" Checkpoint: {checkpoint_path}")
 
242
 
243
 
244
  # ─────────────────────────────────────────────────────────────────────────────
245
+ # INFERENCE FUNCTION - Exactly from Colab SECTION 9
246
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
247
  def generate_report(
248
  image_path: str,
249
  model: VisionT5Model,
250
  config: dict
251
  ) -> str:
252
  """
253
+ Generate medical report from X-ray image - EXACT implementation from Colab
 
 
 
 
 
 
 
 
 
254
  """
255
  try:
256
  # Preprocess image
257
+ image = Image.open(image_path).convert('RGB')
258
+ pixel_values = transform(image).unsqueeze(0).to(device)
259
 
260
+ # Generate report - using EXACT parameters from Colab
261
  with torch.no_grad():
262
  generated_ids = model.generate_reports(
263
  pixel_values,
 
276
 
277
 
278
  # ─────────────────────────────────────────────────────────────────────────────
279
+ # LOAD MODELS FROM HUGGINGFACE
280
  # ─────────────────────────────────────────────────────────────────────────────
281
  print("\n" + "="*80)
282
+ print("LOADING MODELS FROM HUGGINGFACE")
283
  print("="*80)
284
 
 
 
 
285
  # Download model files from Hugging Face
286
  try:
 
 
 
 
 
287
  SFT_MODEL_PATH = hf_hub_download(
288
+ repo_id="vinaykumarhs2020/RLHF_radiology_model",
289
  filename="best_model.pt"
290
  )
 
 
 
 
 
 
 
 
 
 
 
 
291
  PPO_MODEL_PATH = hf_hub_download(
292
+ repo_id="vinaykumarhs2020/RLHF_radiology_model",
293
  filename="rlhf_model.pt"
294
  )
295
+ print(f"✓ Downloaded SFT model: {SFT_MODEL_PATH}")
296
+ print(f"✓ Downloaded PPO model: {PPO_MODEL_PATH}")
 
 
297
  except Exception as e:
298
+ print(f"❌ Error downloading models: {e}")
299
+ # Fallback to local paths if downloads fail
300
+ SFT_MODEL_PATH = "/content/best_model.pt"
301
+ PPO_MODEL_PATH = "/content/rlhf_model.pt"
302
+ print(f"⚠️ Using local paths instead")
303
+
304
+ # Load both models
 
305
  print("\n" + "="*80)
306
+ print("LOADING MODELS")
307
  print("="*80)
308
 
309
  sft_model = load_model_from_checkpoint(
 
318
  CONFIG
319
  )
320
 
321
+ print("\n Both models loaded successfully!")
 
322
 
323
  # ─────────────────────────────────────────────────────────────────────────────
324
  # FASTAPI APP
325
  # ─────────────────────────────────────────────────────────────────────────────
326
+ app = FastAPI(title="Medical Report Generation - Matching Colab")
 
 
 
 
327
 
328
  app.add_middleware(
329
  CORSMiddleware,
 
333
  )
334
 
335
 
336
+ def preprocess_bytes(file_bytes: bytes) -> torch.Tensor:
337
+ """Preprocess image bytes for inference"""
338
+ img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
339
+ return transform(img).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
 
342
  @app.get("/health")
 
344
  return {
345
  "status": "ok",
346
  "device": str(device),
347
+ "models_loaded": True,
348
+ "config": CONFIG
 
 
 
 
 
 
 
 
349
  }
350
 
351
 
352
  @app.post("/sft")
353
  async def sft_inference(file: UploadFile = File(...)):
354
  """
355
+ SFT model inference - EXACTLY matching Colab behavior
 
 
356
  """
357
  try:
358
+ # Preprocess image
359
+ tensor = preprocess_bytes(await file.read())
 
 
360
 
361
+ # Generate report using EXACT Colab parameters
362
+ with torch.no_grad():
363
+ generated_ids = sft_model.generate_reports(
364
+ tensor,
365
+ max_length=CONFIG['max_length'],
366
+ num_beams=CONFIG['num_beams']
367
+ )
368
 
369
+ # Decode - EXACTLY as Colab does
370
+ report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
371
 
372
+ print(f"[SFT] Generated: {report}")
373
 
374
+ # Return FULL report without truncation
375
+ return {"report": report, "model": "SFT", "config_used": CONFIG}
 
 
 
 
376
 
377
  except Exception as e:
378
  traceback.print_exc()
379
+ return {"report": f"ERROR: {str(e)}", "model": "SFT"}
 
 
 
380
 
381
 
382
  @app.post("/ppo")
383
  async def ppo_inference(file: UploadFile = File(...)):
384
  """
385
+ PPO model inference - EXACTLY matching Colab behavior
 
 
386
  """
387
  try:
388
+ # Preprocess image
389
+ tensor = preprocess_bytes(await file.read())
 
 
390
 
391
+ # Generate report using EXACT Colab parameters
392
+ with torch.no_grad():
393
+ generated_ids = ppo_model.generate_reports(
394
+ tensor,
395
+ max_length=CONFIG['max_length'],
396
+ num_beams=CONFIG['num_beams']
397
+ )
398
 
399
+ # Decode - EXACTLY as Colab does
400
+ report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
401
 
402
+ print(f"[PPO] Generated: {report}")
403
 
404
+ # Return FULL report without truncation
405
+ return {"report": report, "model": "PPO", "config_used": CONFIG}
 
 
 
 
406
 
407
  except Exception as e:
408
  traceback.print_exc()
409
+ return {"report": f"ERROR: {str(e)}", "model": "PPO"}
 
 
 
410
 
411
 
412
  @app.post("/compare")
413
  async def compare_models(file: UploadFile = File(...)):
414
  """
415
  Generate reports from both models for comparison
 
 
 
416
  """
417
  try:
418
+ file_bytes = await file.read()
419
+ tensor = preprocess_bytes(file_bytes)
 
 
420
 
421
+ # SFT Generation
422
+ with torch.no_grad():
423
+ sft_ids = sft_model.generate_reports(
424
+ tensor,
425
+ max_length=CONFIG['max_length'],
426
+ num_beams=CONFIG['num_beams']
427
+ )
428
+ sft_report = tokenizer.decode(sft_ids[0], skip_special_tokens=True).strip()
429
 
430
+ # PPO Generation
431
+ with torch.no_grad():
432
+ ppo_ids = ppo_model.generate_reports(
433
+ tensor,
434
+ max_length=CONFIG['max_length'],
435
+ num_beams=CONFIG['num_beams']
436
+ )
437
+ ppo_report = tokenizer.decode(ppo_ids[0], skip_special_tokens=True).strip()
438
 
439
  print(f"[COMPARE] SFT: {sft_report}")
440
  print(f"[COMPARE] PPO: {ppo_report}")
 
442
  return {
443
  "sft_report": sft_report,
444
  "ppo_report": ppo_report,
445
+ "config_used": CONFIG
 
 
 
 
446
  }
447
 
448
  except Exception as e:
 
453
  }
454
 
455
 
456
+ @app.get("/debug_config")
457
+ def debug_config():
458
+ """Debug endpoint to check configuration"""
 
 
459
  return {
460
+ "config": CONFIG,
461
+ "device": str(device),
462
+ "tokenizer": CONFIG['t5_model'],
463
+ "image_size": CONFIG['image_size'],
464
+ "max_length": CONFIG['max_length'],
465
+ "num_beams": CONFIG['num_beams'],
466
+ "models_loaded": {
467
+ "sft": sft_model is not None,
468
+ "ppo": ppo_model is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  }
470
  }
471
 
 
481
  else:
482
  print("⚠️ Build directory not found, serving API only")
483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
  if __name__ == "__main__":
486
  import uvicorn