Shree2604 committed on
Commit
02dca55
·
verified ·
1 Parent(s): b173385

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +60 -137
server.py CHANGED
@@ -320,57 +320,34 @@ def generate_report(
320
 
321
 
322
  # ─────────────────────────────────────────────────────────────────────────────
323
- # LOAD MODELS FROM HUGGINGFACE - Shree2604/BioStack
324
  # ─────────────────────────────────────────────────────────────────────────────
325
  print("\n" + "="*80)
326
- print("DOWNLOADING MODELS FROM HUGGINGFACE")
327
  print("="*80)
328
 
329
- # Hugging Face repository
330
- HF_REPO = "Shree2604/BioStack"
331
-
332
  # Download model files from Hugging Face
333
  try:
334
- print(f"📦 Downloading from repository: {HF_REPO}")
335
- print("This may take a few minutes on first run...\n")
336
-
337
- # Download SFT model
338
- print("1️⃣ Downloading SFT model (best_model.pt)...")
339
  SFT_MODEL_PATH = hf_hub_download(
340
- repo_id=HF_REPO,
341
  filename="best_model.pt"
342
  )
343
- print(f" ✓ SFT model downloaded: {SFT_MODEL_PATH}")
344
-
345
- # Download Reward model
346
- print("\n2️⃣ Downloading Reward model (reward_model.pt)...")
347
- REWARD_MODEL_PATH = hf_hub_download(
348
- repo_id=HF_REPO,
349
- filename="reward_model.pt"
350
- )
351
- print(f" ✓ Reward model downloaded: {REWARD_MODEL_PATH}")
352
-
353
- # Download PPO model
354
- print("\n3️⃣ Downloading PPO model (rlhf_model.pt)...")
355
  PPO_MODEL_PATH = hf_hub_download(
356
- repo_id=HF_REPO,
357
  filename="rlhf_model.pt"
358
  )
359
- print(f" PPO model downloaded: {PPO_MODEL_PATH}")
360
-
361
- print("\n✅ All models downloaded successfully!")
362
-
363
  except Exception as e:
364
- print(f"\n❌ Error downloading models from Hugging Face: {e}")
365
- print("Please check:")
366
- print(f" - Repository exists: https://huggingface.co/{HF_REPO}")
367
- print(" - Files exist: best_model.pt, reward_model.pt, rlhf_model.pt")
368
- print(" - You have internet connection")
369
- raise
370
 
371
  # Load both models - EXACTLY as Colab SECTION 8
372
  print("\n" + "="*80)
373
- print("LOADING MODELS INTO MEMORY")
374
  print("="*80)
375
 
376
  sft_model = load_model_from_checkpoint(
@@ -385,17 +362,12 @@ ppo_model = load_model_from_checkpoint(
385
  CONFIG
386
  )
387
 
388
- print("\n Both models loaded successfully!")
389
- print("="*80)
390
 
391
  # ─────────────────────────────────────────────────────────────────────────────
392
  # FASTAPI APP
393
  # ─────────────────────────────────────────────────────────────────────────────
394
- app = FastAPI(
395
- title="BioStack Medical Report Generation",
396
- description="Medical X-ray report generation using SFT and PPO models from Shree2604/BioStack",
397
- version="1.0.0"
398
- )
399
 
400
  app.add_middleware(
401
  CORSMiddleware,
@@ -405,40 +377,14 @@ app.add_middleware(
405
  )
406
 
407
 
408
- @app.get("/")
409
- def root():
410
- return {
411
- "message": "BioStack Medical Report Generation API",
412
- "repository": "Shree2604/BioStack",
413
- "models": {
414
- "sft": "best_model.pt",
415
- "ppo": "rlhf_model.pt",
416
- "reward": "reward_model.pt"
417
- },
418
- "endpoints": {
419
- "health": "GET /health - Check API status",
420
- "sft": "POST /sft - Generate report using SFT model",
421
- "ppo": "POST /ppo - Generate report using PPO model",
422
- "compare": "POST /compare - Compare both models"
423
- }
424
- }
425
-
426
-
427
  @app.get("/health")
428
  def health():
429
  return {
430
  "status": "ok",
431
  "device": str(device),
432
  "cuda_available": torch.cuda.is_available(),
433
- "models_loaded": {
434
- "sft": sft_model is not None,
435
- "ppo": ppo_model is not None
436
- },
437
- "repository": HF_REPO,
438
- "model_files": {
439
- "sft": os.path.basename(SFT_MODEL_PATH),
440
- "ppo": os.path.basename(PPO_MODEL_PATH)
441
- }
442
  }
443
 
444
 
@@ -446,8 +392,6 @@ def health():
446
  async def sft_inference(file: UploadFile = File(...)):
447
  """
448
  SFT model inference - Uses EXACT generate_report() function from Colab SECTION 9
449
-
450
- Model: best_model.pt from Shree2604/BioStack
451
  """
452
  try:
453
  # Save uploaded file temporarily
@@ -466,24 +410,18 @@ async def sft_inference(file: UploadFile = File(...)):
466
  return {
467
  "report": report,
468
  "model": "SFT",
469
- "source": "best_model.pt",
470
- "repository": HF_REPO
471
  }
472
 
473
  except Exception as e:
474
  traceback.print_exc()
475
- return {
476
- "report": f"ERROR: {str(e)}",
477
- "model": "SFT"
478
- }
479
 
480
 
481
  @app.post("/ppo")
482
  async def ppo_inference(file: UploadFile = File(...)):
483
  """
484
  PPO model inference - Uses EXACT generate_report() function from Colab SECTION 9
485
-
486
- Model: rlhf_model.pt from Shree2604/BioStack
487
  """
488
  try:
489
  # Save uploaded file temporarily
@@ -502,16 +440,12 @@ async def ppo_inference(file: UploadFile = File(...)):
502
  return {
503
  "report": report,
504
  "model": "PPO",
505
- "source": "rlhf_model.pt",
506
- "repository": HF_REPO
507
  }
508
 
509
  except Exception as e:
510
  traceback.print_exc()
511
- return {
512
- "report": f"ERROR: {str(e)}",
513
- "model": "PPO"
514
- }
515
 
516
 
517
  @app.post("/compare")
@@ -519,8 +453,6 @@ async def compare_models(file: UploadFile = File(...)):
519
  """
520
  Generate reports from both models for comparison
521
  Uses EXACT generate_report() function from Colab
522
-
523
- Models: best_model.pt and rlhf_model.pt from Shree2604/BioStack
524
  """
525
  try:
526
  # Save uploaded file temporarily
@@ -541,11 +473,8 @@ async def compare_models(file: UploadFile = File(...)):
541
  return {
542
  "sft_report": sft_report,
543
  "ppo_report": ppo_report,
544
- "models": {
545
- "sft": "best_model.pt",
546
- "ppo": "rlhf_model.pt"
547
- },
548
- "repository": HF_REPO
549
  }
550
 
551
  except Exception as e:
@@ -556,45 +485,43 @@ async def compare_models(file: UploadFile = File(...)):
556
  }
557
 
558
 
559
- @app.get("/model_info")
560
- def model_info():
561
  """
562
- Get detailed information about loaded models
563
  """
564
  return {
565
- "repository": HF_REPO,
566
- "repository_url": f"https://huggingface.co/{HF_REPO}",
567
- "models": {
568
- "sft": {
569
- "filename": "best_model.pt",
570
- "url": f"https://huggingface.co/{HF_REPO}/blob/main/best_model.pt",
571
- "local_path": SFT_MODEL_PATH,
572
- "loaded": sft_model is not None,
573
- "in_eval_mode": not sft_model.training if sft_model else None
574
- },
575
- "ppo": {
576
- "filename": "rlhf_model.pt",
577
- "url": f"https://huggingface.co/{HF_REPO}/blob/main/rlhf_model.pt",
578
- "local_path": PPO_MODEL_PATH,
579
- "loaded": ppo_model is not None,
580
- "in_eval_mode": not ppo_model.training if ppo_model else None
581
- },
582
- "reward": {
583
- "filename": "reward_model.pt",
584
- "url": f"https://huggingface.co/{HF_REPO}/blob/main/reward_model.pt",
585
- "local_path": REWARD_MODEL_PATH,
586
- "note": "Downloaded but not loaded in this API"
587
- }
588
  },
589
- "architecture": {
590
- "vision_encoder": CONFIG['coatnet_model'],
591
- "text_model": CONFIG['t5_model'],
592
- "image_embedding_dim": CONFIG['img_emb_dim']
 
593
  },
594
- "inference_config": {
595
  "max_length": CONFIG['max_length'],
596
  "num_beams": CONFIG['num_beams'],
597
- "image_size": CONFIG['image_size']
 
 
 
 
 
 
 
 
 
 
598
  }
599
  }
600
 
@@ -611,19 +538,15 @@ else:
611
  print("⚠️ Build directory not found, serving API only")
612
 
613
  print("\n" + "="*80)
614
- print("🚀 SERVER READY")
615
  print("="*80)
616
- print(f"Repository: {HF_REPO}")
617
- print("Models loaded:")
618
- print(f" ✓ SFT: best_model.pt")
619
- print(f" ✓ PPO: rlhf_model.pt")
620
- print("\nEndpoints:")
621
- print(" GET / - API info")
622
- print(" GET /health - Health check")
623
- print(" GET /model_info - Model details")
624
- print(" POST /sft - SFT inference")
625
- print(" POST /ppo - PPO inference")
626
- print(" POST /compare - Compare both models")
627
  print("="*80)
628
 
629
  if __name__ == "__main__":
 
320
 
321
 
322
  # ─────────────────────────────────────────────────────────────────────────────
323
+ # LOAD MODELS FROM HUGGINGFACE
324
  # ─────────────────────────────────────────────────────────────────────────────
325
  print("\n" + "="*80)
326
+ print("LOADING MODELS FROM HUGGINGFACE")
327
  print("="*80)
328
 
 
 
 
329
  # Download model files from Hugging Face
330
  try:
 
 
 
 
 
331
  SFT_MODEL_PATH = hf_hub_download(
332
+ repo_id="vinaykumarhs2020/RLHF_radiology_model",
333
  filename="best_model.pt"
334
  )
 
 
 
 
 
 
 
 
 
 
 
 
335
  PPO_MODEL_PATH = hf_hub_download(
336
+ repo_id="vinaykumarhs2020/RLHF_radiology_model",
337
  filename="rlhf_model.pt"
338
  )
339
+ print(f"✓ Downloaded SFT model: {SFT_MODEL_PATH}")
340
+ print(f"✓ Downloaded PPO model: {PPO_MODEL_PATH}")
 
 
341
  except Exception as e:
342
+ print(f"❌ Error downloading models: {e}")
343
+ # Fallback to local paths if downloads fail
344
+ SFT_MODEL_PATH = "/content/best_model.pt"
345
+ PPO_MODEL_PATH = "/content/rlhf_model.pt"
346
+ print(f"⚠️ Using local paths instead")
 
347
 
348
  # Load both models - EXACTLY as Colab SECTION 8
349
  print("\n" + "="*80)
350
+ print("LOADING MODELS")
351
  print("="*80)
352
 
353
  sft_model = load_model_from_checkpoint(
 
362
  CONFIG
363
  )
364
 
365
+ print("\n Both models loaded successfully!")
 
366
 
367
  # ─────────────────────────────────────────────────────────────────────────────
368
  # FASTAPI APP
369
  # ─────────────────────────────────────────────────────────────────────────────
370
+ app = FastAPI(title="Medical Report Generation - Exact Colab Match")
 
 
 
 
371
 
372
  app.add_middleware(
373
  CORSMiddleware,
 
377
  )
378
 
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  @app.get("/health")
381
  def health():
382
  return {
383
  "status": "ok",
384
  "device": str(device),
385
  "cuda_available": torch.cuda.is_available(),
386
+ "models_loaded": True,
387
+ "config": {k: v for k, v in CONFIG.items() if k != 'device'}
 
 
 
 
 
 
 
388
  }
389
 
390
 
 
392
  async def sft_inference(file: UploadFile = File(...)):
393
  """
394
  SFT model inference - Uses EXACT generate_report() function from Colab SECTION 9
 
 
395
  """
396
  try:
397
  # Save uploaded file temporarily
 
410
  return {
411
  "report": report,
412
  "model": "SFT",
413
+ "method": "generate_report() - exact Colab SECTION 9"
 
414
  }
415
 
416
  except Exception as e:
417
  traceback.print_exc()
418
+ return {"report": f"ERROR: {str(e)}", "model": "SFT"}
 
 
 
419
 
420
 
421
  @app.post("/ppo")
422
  async def ppo_inference(file: UploadFile = File(...)):
423
  """
424
  PPO model inference - Uses EXACT generate_report() function from Colab SECTION 9
 
 
425
  """
426
  try:
427
  # Save uploaded file temporarily
 
440
  return {
441
  "report": report,
442
  "model": "PPO",
443
+ "method": "generate_report() - exact Colab SECTION 9"
 
444
  }
445
 
446
  except Exception as e:
447
  traceback.print_exc()
448
+ return {"report": f"ERROR: {str(e)}", "model": "PPO"}
 
 
 
449
 
450
 
451
  @app.post("/compare")
 
453
  """
454
  Generate reports from both models for comparison
455
  Uses EXACT generate_report() function from Colab
 
 
456
  """
457
  try:
458
  # Save uploaded file temporarily
 
473
  return {
474
  "sft_report": sft_report,
475
  "ppo_report": ppo_report,
476
+ "method": "generate_report() - exact Colab SECTION 9",
477
+ "config": {k: v for k, v in CONFIG.items() if k != 'device'}
 
 
 
478
  }
479
 
480
  except Exception as e:
 
485
  }
486
 
487
 
488
+ @app.get("/debug_inference")
489
+ def debug_inference():
490
  """
491
+ Debug endpoint to verify inference setup matches Colab exactly
492
  """
493
  return {
494
+ "device": str(device),
495
+ "cuda_available": torch.cuda.is_available(),
496
+ "config": {
497
+ "coatnet_model": CONFIG['coatnet_model'],
498
+ "t5_model": CONFIG['t5_model'],
499
+ "img_emb_dim": CONFIG['img_emb_dim'],
500
+ "train_last_stages": CONFIG['train_last_stages'],
501
+ "image_size": CONFIG['image_size'],
502
+ "max_length": CONFIG['max_length'],
503
+ "num_beams": CONFIG['num_beams'],
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  },
505
+ "tokenizer": CONFIG['t5_model'],
506
+ "transform": {
507
+ "resize": f"{CONFIG['image_size']}x{CONFIG['image_size']}",
508
+ "normalize_mean": [0.485, 0.456, 0.406],
509
+ "normalize_std": [0.229, 0.224, 0.225]
510
  },
511
+ "generation_params": {
512
  "max_length": CONFIG['max_length'],
513
  "num_beams": CONFIG['num_beams'],
514
+ "early_stopping": True,
515
+ "no_extra_penalties": "✓ Exactly as Colab"
516
+ },
517
+ "inference_method": "generate_report() from Colab SECTION 9",
518
+ "models_loaded": {
519
+ "sft": sft_model is not None,
520
+ "ppo": ppo_model is not None
521
+ },
522
+ "model_state": {
523
+ "sft_eval_mode": not sft_model.training if sft_model else None,
524
+ "ppo_eval_mode": not ppo_model.training if ppo_model else None
525
  }
526
  }
527
 
 
538
  print("⚠️ Build directory not found, serving API only")
539
 
540
  print("\n" + "="*80)
541
+ print("SERVER READY - Using EXACT Colab Inference Code")
542
  print("="*80)
543
+ print("Key points:")
544
+ print(" Model architecture: VisionT5Model (exact copy from Colab SECTION 6)")
545
+ print(" ✓ Inference method: generate_report() (exact copy from Colab SECTION 9)")
546
+ print(" ✓ Generation params: max_length=100, num_beams=4, early_stopping=True")
547
+ print(" ✓ No extra penalties: NO repetition_penalty, NO no_repeat_ngram_size")
548
+ print(" Transform: Resize 224x224, Normalize [0.485,0.456,0.406]/[0.229,0.224,0.225]")
549
+ print(" Device handling: Same as Colab")
 
 
 
 
550
  print("="*80)
551
 
552
  if __name__ == "__main__":