Shree2604 commited on
Commit
d8c1e26
Β·
verified Β·
1 Parent(s): 02dca55

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +137 -60
server.py CHANGED
@@ -320,34 +320,57 @@ def generate_report(
320
 
321
 
322
  # ─────────────────────────────────────────────────────────────────────────────
323
- # LOAD MODELS FROM HUGGINGFACE
324
  # ─────────────────────────────────────────────────────────────────────────────
325
  print("\n" + "="*80)
326
- print("LOADING MODELS FROM HUGGINGFACE")
327
  print("="*80)
328
 
 
 
 
329
  # Download model files from Hugging Face
330
  try:
 
 
 
 
 
331
  SFT_MODEL_PATH = hf_hub_download(
332
- repo_id="vinaykumarhs2020/RLHF_radiology_model",
333
  filename="best_model.pt"
334
  )
 
 
 
 
 
 
 
 
 
 
 
 
335
  PPO_MODEL_PATH = hf_hub_download(
336
- repo_id="vinaykumarhs2020/RLHF_radiology_model",
337
  filename="rlhf_model.pt"
338
  )
339
- print(f"βœ“ Downloaded SFT model: {SFT_MODEL_PATH}")
340
- print(f"βœ“ Downloaded PPO model: {PPO_MODEL_PATH}")
 
 
341
  except Exception as e:
342
- print(f"❌ Error downloading models: {e}")
343
- # Fallback to local paths if downloads fail
344
- SFT_MODEL_PATH = "/content/best_model.pt"
345
- PPO_MODEL_PATH = "/content/rlhf_model.pt"
346
- print(f"⚠️ Using local paths instead")
 
347
 
348
  # Load both models - EXACTLY as Colab SECTION 8
349
  print("\n" + "="*80)
350
- print("LOADING MODELS")
351
  print("="*80)
352
 
353
  sft_model = load_model_from_checkpoint(
@@ -362,12 +385,17 @@ ppo_model = load_model_from_checkpoint(
362
  CONFIG
363
  )
364
 
365
- print("\nβœ“ Both models loaded successfully!")
 
366
 
367
  # ─────────────────────────────────────────────────────────────────────────────
368
  # FASTAPI APP
369
  # ─────────────────────────────────────────────────────────────────────────────
370
- app = FastAPI(title="Medical Report Generation - Exact Colab Match")
 
 
 
 
371
 
372
  app.add_middleware(
373
  CORSMiddleware,
@@ -377,14 +405,40 @@ app.add_middleware(
377
  )
378
 
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  @app.get("/health")
381
  def health():
382
  return {
383
  "status": "ok",
384
  "device": str(device),
385
  "cuda_available": torch.cuda.is_available(),
386
- "models_loaded": True,
387
- "config": {k: v for k, v in CONFIG.items() if k != 'device'}
 
 
 
 
 
 
 
388
  }
389
 
390
 
@@ -392,6 +446,8 @@ def health():
392
  async def sft_inference(file: UploadFile = File(...)):
393
  """
394
  SFT model inference - Uses EXACT generate_report() function from Colab SECTION 9
 
 
395
  """
396
  try:
397
  # Save uploaded file temporarily
@@ -410,18 +466,24 @@ async def sft_inference(file: UploadFile = File(...)):
410
  return {
411
  "report": report,
412
  "model": "SFT",
413
- "method": "generate_report() - exact Colab SECTION 9"
 
414
  }
415
 
416
  except Exception as e:
417
  traceback.print_exc()
418
- return {"report": f"ERROR: {str(e)}", "model": "SFT"}
 
 
 
419
 
420
 
421
  @app.post("/ppo")
422
  async def ppo_inference(file: UploadFile = File(...)):
423
  """
424
  PPO model inference - Uses EXACT generate_report() function from Colab SECTION 9
 
 
425
  """
426
  try:
427
  # Save uploaded file temporarily
@@ -440,12 +502,16 @@ async def ppo_inference(file: UploadFile = File(...)):
440
  return {
441
  "report": report,
442
  "model": "PPO",
443
- "method": "generate_report() - exact Colab SECTION 9"
 
444
  }
445
 
446
  except Exception as e:
447
  traceback.print_exc()
448
- return {"report": f"ERROR: {str(e)}", "model": "PPO"}
 
 
 
449
 
450
 
451
  @app.post("/compare")
@@ -453,6 +519,8 @@ async def compare_models(file: UploadFile = File(...)):
453
  """
454
  Generate reports from both models for comparison
455
  Uses EXACT generate_report() function from Colab
 
 
456
  """
457
  try:
458
  # Save uploaded file temporarily
@@ -473,8 +541,11 @@ async def compare_models(file: UploadFile = File(...)):
473
  return {
474
  "sft_report": sft_report,
475
  "ppo_report": ppo_report,
476
- "method": "generate_report() - exact Colab SECTION 9",
477
- "config": {k: v for k, v in CONFIG.items() if k != 'device'}
 
 
 
478
  }
479
 
480
  except Exception as e:
@@ -485,43 +556,45 @@ async def compare_models(file: UploadFile = File(...)):
485
  }
486
 
487
 
488
- @app.get("/debug_inference")
489
- def debug_inference():
490
  """
491
- Debug endpoint to verify inference setup matches Colab exactly
492
  """
493
  return {
494
- "device": str(device),
495
- "cuda_available": torch.cuda.is_available(),
496
- "config": {
497
- "coatnet_model": CONFIG['coatnet_model'],
498
- "t5_model": CONFIG['t5_model'],
499
- "img_emb_dim": CONFIG['img_emb_dim'],
500
- "train_last_stages": CONFIG['train_last_stages'],
501
- "image_size": CONFIG['image_size'],
502
- "max_length": CONFIG['max_length'],
503
- "num_beams": CONFIG['num_beams'],
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  },
505
- "tokenizer": CONFIG['t5_model'],
506
- "transform": {
507
- "resize": f"{CONFIG['image_size']}x{CONFIG['image_size']}",
508
- "normalize_mean": [0.485, 0.456, 0.406],
509
- "normalize_std": [0.229, 0.224, 0.225]
510
  },
511
- "generation_params": {
512
  "max_length": CONFIG['max_length'],
513
  "num_beams": CONFIG['num_beams'],
514
- "early_stopping": True,
515
- "no_extra_penalties": "βœ“ Exactly as Colab"
516
- },
517
- "inference_method": "generate_report() from Colab SECTION 9",
518
- "models_loaded": {
519
- "sft": sft_model is not None,
520
- "ppo": ppo_model is not None
521
- },
522
- "model_state": {
523
- "sft_eval_mode": not sft_model.training if sft_model else None,
524
- "ppo_eval_mode": not ppo_model.training if ppo_model else None
525
  }
526
  }
527
 
@@ -538,15 +611,19 @@ else:
538
  print("⚠️ Build directory not found, serving API only")
539
 
540
  print("\n" + "="*80)
541
- print("SERVER READY - Using EXACT Colab Inference Code")
542
  print("="*80)
543
- print("Key points:")
544
- print(" βœ“ Model architecture: VisionT5Model (exact copy from Colab SECTION 6)")
545
- print(" βœ“ Inference method: generate_report() (exact copy from Colab SECTION 9)")
546
- print(" βœ“ Generation params: max_length=100, num_beams=4, early_stopping=True")
547
- print(" βœ“ No extra penalties: NO repetition_penalty, NO no_repeat_ngram_size")
548
- print(" βœ“ Transform: Resize 224x224, Normalize [0.485,0.456,0.406]/[0.229,0.224,0.225]")
549
- print(" βœ“ Device handling: Same as Colab")
 
 
 
 
550
  print("="*80)
551
 
552
  if __name__ == "__main__":
 
320
 
321
 
322
  # ─────────────────────────────────────────────────────────────────────────────
323
+ # LOAD MODELS FROM HUGGINGFACE - Shree2604/BioStack
324
  # ─────────────────────────────────────────────────────────────────────────────
325
  print("\n" + "="*80)
326
+ print("DOWNLOADING MODELS FROM HUGGINGFACE")
327
  print("="*80)
328
 
329
+ # Hugging Face repository
330
+ HF_REPO = "Shree2604/BioStack"
331
+
332
  # Download model files from Hugging Face
333
  try:
334
+ print(f"πŸ“¦ Downloading from repository: {HF_REPO}")
335
+ print("This may take a few minutes on first run...\n")
336
+
337
+ # Download SFT model
338
+ print("1️⃣ Downloading SFT model (best_model.pt)...")
339
  SFT_MODEL_PATH = hf_hub_download(
340
+ repo_id=HF_REPO,
341
  filename="best_model.pt"
342
  )
343
+ print(f" βœ“ SFT model downloaded: {SFT_MODEL_PATH}")
344
+
345
+ # Download Reward model
346
+ print("\n2️⃣ Downloading Reward model (reward_model.pt)...")
347
+ REWARD_MODEL_PATH = hf_hub_download(
348
+ repo_id=HF_REPO,
349
+ filename="reward_model.pt"
350
+ )
351
+ print(f" βœ“ Reward model downloaded: {REWARD_MODEL_PATH}")
352
+
353
+ # Download PPO model
354
+ print("\n3️⃣ Downloading PPO model (rlhf_model.pt)...")
355
  PPO_MODEL_PATH = hf_hub_download(
356
+ repo_id=HF_REPO,
357
  filename="rlhf_model.pt"
358
  )
359
+ print(f" βœ“ PPO model downloaded: {PPO_MODEL_PATH}")
360
+
361
+ print("\nβœ… All models downloaded successfully!")
362
+
363
  except Exception as e:
364
+ print(f"\n❌ Error downloading models from Hugging Face: {e}")
365
+ print("Please check:")
366
+ print(f" - Repository exists: https://huggingface.co/{HF_REPO}")
367
+ print(" - Files exist: best_model.pt, reward_model.pt, rlhf_model.pt")
368
+ print(" - You have internet connection")
369
+ raise
370
 
371
  # Load both models - EXACTLY as Colab SECTION 8
372
  print("\n" + "="*80)
373
+ print("LOADING MODELS INTO MEMORY")
374
  print("="*80)
375
 
376
  sft_model = load_model_from_checkpoint(
 
385
  CONFIG
386
  )
387
 
388
+ print("\nβœ… Both models loaded successfully!")
389
+ print("="*80)
390
 
391
  # ─────────────────────────────────────────────────────────────────────────────
392
  # FASTAPI APP
393
  # ─────────────────────────────────────────────────────────────────────────────
394
+ app = FastAPI(
395
+ title="BioStack Medical Report Generation",
396
+ description="Medical X-ray report generation using SFT and PPO models from Shree2604/BioStack",
397
+ version="1.0.0"
398
+ )
399
 
400
  app.add_middleware(
401
  CORSMiddleware,
 
405
  )
406
 
407
 
408
+ @app.get("/")
409
+ def root():
410
+ return {
411
+ "message": "BioStack Medical Report Generation API",
412
+ "repository": "Shree2604/BioStack",
413
+ "models": {
414
+ "sft": "best_model.pt",
415
+ "ppo": "rlhf_model.pt",
416
+ "reward": "reward_model.pt"
417
+ },
418
+ "endpoints": {
419
+ "health": "GET /health - Check API status",
420
+ "sft": "POST /sft - Generate report using SFT model",
421
+ "ppo": "POST /ppo - Generate report using PPO model",
422
+ "compare": "POST /compare - Compare both models"
423
+ }
424
+ }
425
+
426
+
427
  @app.get("/health")
428
  def health():
429
  return {
430
  "status": "ok",
431
  "device": str(device),
432
  "cuda_available": torch.cuda.is_available(),
433
+ "models_loaded": {
434
+ "sft": sft_model is not None,
435
+ "ppo": ppo_model is not None
436
+ },
437
+ "repository": HF_REPO,
438
+ "model_files": {
439
+ "sft": os.path.basename(SFT_MODEL_PATH),
440
+ "ppo": os.path.basename(PPO_MODEL_PATH)
441
+ }
442
  }
443
 
444
 
 
446
  async def sft_inference(file: UploadFile = File(...)):
447
  """
448
  SFT model inference - Uses EXACT generate_report() function from Colab SECTION 9
449
+
450
+ Model: best_model.pt from Shree2604/BioStack
451
  """
452
  try:
453
  # Save uploaded file temporarily
 
466
  return {
467
  "report": report,
468
  "model": "SFT",
469
+ "source": "best_model.pt",
470
+ "repository": HF_REPO
471
  }
472
 
473
  except Exception as e:
474
  traceback.print_exc()
475
+ return {
476
+ "report": f"ERROR: {str(e)}",
477
+ "model": "SFT"
478
+ }
479
 
480
 
481
  @app.post("/ppo")
482
  async def ppo_inference(file: UploadFile = File(...)):
483
  """
484
  PPO model inference - Uses EXACT generate_report() function from Colab SECTION 9
485
+
486
+ Model: rlhf_model.pt from Shree2604/BioStack
487
  """
488
  try:
489
  # Save uploaded file temporarily
 
502
  return {
503
  "report": report,
504
  "model": "PPO",
505
+ "source": "rlhf_model.pt",
506
+ "repository": HF_REPO
507
  }
508
 
509
  except Exception as e:
510
  traceback.print_exc()
511
+ return {
512
+ "report": f"ERROR: {str(e)}",
513
+ "model": "PPO"
514
+ }
515
 
516
 
517
  @app.post("/compare")
 
519
  """
520
  Generate reports from both models for comparison
521
  Uses EXACT generate_report() function from Colab
522
+
523
+ Models: best_model.pt and rlhf_model.pt from Shree2604/BioStack
524
  """
525
  try:
526
  # Save uploaded file temporarily
 
541
  return {
542
  "sft_report": sft_report,
543
  "ppo_report": ppo_report,
544
+ "models": {
545
+ "sft": "best_model.pt",
546
+ "ppo": "rlhf_model.pt"
547
+ },
548
+ "repository": HF_REPO
549
  }
550
 
551
  except Exception as e:
 
556
  }
557
 
558
 
559
+ @app.get("/model_info")
560
+ def model_info():
561
  """
562
+ Get detailed information about loaded models
563
  """
564
  return {
565
+ "repository": HF_REPO,
566
+ "repository_url": f"https://huggingface.co/{HF_REPO}",
567
+ "models": {
568
+ "sft": {
569
+ "filename": "best_model.pt",
570
+ "url": f"https://huggingface.co/{HF_REPO}/blob/main/best_model.pt",
571
+ "local_path": SFT_MODEL_PATH,
572
+ "loaded": sft_model is not None,
573
+ "in_eval_mode": not sft_model.training if sft_model else None
574
+ },
575
+ "ppo": {
576
+ "filename": "rlhf_model.pt",
577
+ "url": f"https://huggingface.co/{HF_REPO}/blob/main/rlhf_model.pt",
578
+ "local_path": PPO_MODEL_PATH,
579
+ "loaded": ppo_model is not None,
580
+ "in_eval_mode": not ppo_model.training if ppo_model else None
581
+ },
582
+ "reward": {
583
+ "filename": "reward_model.pt",
584
+ "url": f"https://huggingface.co/{HF_REPO}/blob/main/reward_model.pt",
585
+ "local_path": REWARD_MODEL_PATH,
586
+ "note": "Downloaded but not loaded in this API"
587
+ }
588
  },
589
+ "architecture": {
590
+ "vision_encoder": CONFIG['coatnet_model'],
591
+ "text_model": CONFIG['t5_model'],
592
+ "image_embedding_dim": CONFIG['img_emb_dim']
 
593
  },
594
+ "inference_config": {
595
  "max_length": CONFIG['max_length'],
596
  "num_beams": CONFIG['num_beams'],
597
+ "image_size": CONFIG['image_size']
 
 
 
 
 
 
 
 
 
 
598
  }
599
  }
600
 
 
611
  print("⚠️ Build directory not found, serving API only")
612
 
613
  print("\n" + "="*80)
614
+ print("πŸš€ SERVER READY")
615
  print("="*80)
616
+ print(f"Repository: {HF_REPO}")
617
+ print("Models loaded:")
618
+ print(f" βœ“ SFT: best_model.pt")
619
+ print(f" βœ“ PPO: rlhf_model.pt")
620
+ print("\nEndpoints:")
621
+ print(" GET / - API info")
622
+ print(" GET /health - Health check")
623
+ print(" GET /model_info - Model details")
624
+ print(" POST /sft - SFT inference")
625
+ print(" POST /ppo - PPO inference")
626
+ print(" POST /compare - Compare both models")
627
  print("="*80)
628
 
629
  if __name__ == "__main__":