AE-Shree committed on
Commit
1537418
·
1 Parent(s): c934b38

Deploy BioStack RLHF Medical Demo

Browse files
Files changed (1) hide show
  1. server.py +148 -2
server.py CHANGED
@@ -337,6 +337,38 @@ def preprocess(file_bytes: bytes) -> torch.Tensor:
337
  return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
338
 
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  # ─────────────────────────────────────────────────────────────────────────────
341
  # REWARD FEEDBACK GENERATOR
342
  # ─────────────────────────────────────────────────────────────────────────────
@@ -390,8 +422,72 @@ def health():
390
  async def sft_inference(file: UploadFile = File(...)):
391
  try:
392
  tensor = preprocess(await file.read())
393
- report = sft_model.generate_reports(tensor)[0]
394
- print(f"[SFT] Generated: {report}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  return {"report": report[:81]}
396
  except Exception as e:
397
  traceback.print_exc()
@@ -474,6 +570,56 @@ async def ppo_inference(file: UploadFile = File(...)):
474
  # DIAGNOSTIC ENDPOINT — call GET /debug_keys to verify key names in your files
475
  # e.g. curl http://localhost:8000/debug_keys
476
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  @app.get("/debug_keys")
478
  def debug_keys():
479
  import os
 
337
  return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
338
 
339
 
340
+ # ─────────────────────────────────────────────────────────────────────────────
341
+ # DEBUGGING TOOLS - Compare with Jupyter notebook results
342
+ # ─────────────────────────────────────────────────────────────────────────────
343
+ import hashlib
344
+
345
def get_model_hash(model):
    """Return an MD5 hex digest that fingerprints ``model.state_dict()``.

    The digest is built from the raw bytes of every tensor in the state
    dict, together with its key, dtype and shape. Hashing the *printed*
    form of the state dict is not reliable: large tensors are abbreviated
    with ``...`` and values are rounded when printed, so two different
    checkpoints could produce the same hash.

    Args:
        model: Any object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A 32-character hexadecimal MD5 string. It is intended only for
        comparing weights between environments, not for security purposes.
    """
    digest = hashlib.md5()
    state = model.state_dict()
    # Sort the keys so the digest does not depend on registration order.
    for key in sorted(state):
        value = state[key]
        digest.update(key.encode("utf-8"))
        if torch.is_tensor(value):
            # Move to CPU and flatten so the byte view is device-independent
            # and also works for 0-dim tensors.
            flat = value.detach().cpu().contiguous().reshape(-1)
            digest.update(f"{flat.dtype}{tuple(value.shape)}".encode("utf-8"))
            # Reinterpret as uint8 so dtypes NumPy cannot represent
            # (e.g. bfloat16) can still be serialised.
            digest.update(flat.view(torch.uint8).numpy().tobytes())
        else:
            # Non-tensor entries (extra state) fall back to their repr.
            digest.update(repr(value).encode("utf-8"))
    return digest.hexdigest()
349
+
350
def log_inference_details(model_name, image_tensor, generated_ids, decoded_report):
    """Print a block of diagnostics for one inference call.

    The model is looked up among this module's globals under the name
    ``"<model_name.lower()>_model"`` (for example ``"SFT"`` -> ``sft_model``).
    If no such global exists, the model-specific lines report a placeholder
    instead of raising ``KeyError``, so a logging helper cannot break the
    request that called it.

    Args:
        model_name: Label of the model; also used to locate the global.
        image_tensor: Input tensor; its shape, mean and std are printed.
        generated_ids: Token ids returned by generation; must expose ``.shape``.
        decoded_report: Decoded text of the generation.

    Returns:
        None. All output goes to stdout.
    """
    # Single, tolerant lookup instead of two unguarded globals()[...] accesses.
    model = globals().get(f"{model_name.lower()}_model")
    rule = "=" * 50

    print(f"\n{rule}")
    print(f" {model_name} INFERENCE DEBUG")
    print(rule)
    if model is None:
        print("Model hash: unavailable")
    else:
        print(f"Model hash: {get_model_hash(model)}")
    print(f"Image tensor shape: {image_tensor.shape}")
    print(f"Image tensor mean: {image_tensor.mean():.6f}")
    print(f"Image tensor std: {image_tensor.std():.6f}")
    if model is None:
        print("Model in eval mode: unknown")
    else:
        print(f"Model in eval mode: {not model.training}")
    print(f"Generated IDs: {generated_ids}")
    print(f"Generated IDs shape: {generated_ids.shape}")
    print(f"Decoded report: '{decoded_report}'")
    print(f"Report length: {len(decoded_report)} chars")
    print(f"{rule}\n")
365
+
366
# Pin torch's RNG state once, at import time, to a fixed seed so repeated
# runs of the server start from the same generator state. The value 42 is
# also what the /debug_compare endpoint reports as "random_seed".
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
print(" Random seeds set to 42 for reproducible results")
370
+
371
+
372
  # ─────────────────────────────────────────────────────────────────────────────
373
  # REWARD FEEDBACK GENERATOR
374
  # ─────────────────────────────────────────────────────────────────────────────
 
422
  async def sft_inference(file: UploadFile = File(...)):
423
  try:
424
  tensor = preprocess(await file.read())
425
+
426
+ # Enhanced debugging - capture generation details
427
+ print(f"\nπŸ” [SFT] DETAILED INFERENCE ANALYSIS")
428
+ print(f"{'='*60}")
429
+ print(f"Model checkpoint: {SFT_MODEL_PATH}")
430
+ print(f"Image tensor shape: {tensor.shape}")
431
+ print(f"Image tensor device: {tensor.device}")
432
+ print(f"Image tensor mean: {tensor.mean():.6f}")
433
+ print(f"Image tensor std: {tensor.std():.6f}")
434
+ print(f"Model in eval mode: {not sft_model.training}")
435
+ print(f"Using torch.no_grad: True")
436
+
437
+ # Get raw generation output before decoding
438
+ with torch.no_grad():
439
+ img_features = sft_model.img_encoder(tensor)
440
+ img_emb = sft_model.img_proj(img_features).unsqueeze(1)
441
+ batch_size = tensor.size(0)
442
+ img_attn = torch.ones(batch_size, 1, device=tensor.device)
443
+
444
+ encoder_outputs = sft_model.txt_model.encoder(
445
+ inputs_embeds=img_emb,
446
+ attention_mask=img_attn
447
+ )
448
+
449
+ # Log generation parameters
450
+ print(f"Generation parameters:")
451
+ print(f" - max_length: 128")
452
+ print(f" - num_beams: 4")
453
+ print(f" - early_stopping: True")
454
+ print(f" - no_repeat_ngram_size: 3")
455
+ print(f" - repetition_penalty: 1.3")
456
+ print(f" - do_sample: False")
457
+ print(f" - temperature: N/A (deterministic)")
458
+
459
+ generated = sft_model.txt_model.generate(
460
+ encoder_outputs=encoder_outputs,
461
+ attention_mask=img_attn,
462
+ max_length=128,
463
+ num_beams=4,
464
+ early_stopping=True,
465
+ no_repeat_ngram_size=3,
466
+ repetition_penalty=1.3,
467
+ )
468
+
469
+ print(f"Raw generated IDs: {generated}")
470
+ print(f"Generated IDs shape: {generated.shape}")
471
+
472
+ # Decode with same parameters as notebook
473
+ reports = tokenizer.batch_decode(generated, skip_special_tokens=True)
474
+
475
+ # Apply same post-processing
476
+ cleaned_reports = []
477
+ for r in reports:
478
+ if r.lower().startswith("projection:"):
479
+ parts = r.split(".", 1)
480
+ r = parts[1].strip() if len(parts) > 1 else r
481
+ cleaned_reports.append(r)
482
+
483
+ report = cleaned_reports[0]
484
+
485
+ print(f"Decoded report: '{report}'")
486
+ print(f"Report length: {len(report)} chars")
487
+ print(f"Model hash: {get_model_hash(sft_model)}")
488
+ print(f"{'='*60}\n")
489
+
490
+ print(f"[SFT] Final Generated: {report}")
491
  return {"report": report[:81]}
492
  except Exception as e:
493
  traceback.print_exc()
 
570
  # DIAGNOSTIC ENDPOINT — call GET /debug_keys to verify key names in your files
571
  # e.g. curl http://localhost:8000/debug_keys
572
  # ─────────────────────────────────────────────────────────────────────────────
573
@app.get("/debug_compare")
def debug_compare():
    """Return a snapshot of server-side settings for comparing against another environment.

    The response groups five sections: runtime/server info, a hash per
    loaded model, generation parameters, preprocessing settings, and the
    eval-mode flag of each model. The generation and preprocessing values
    are literals written in this function, not read back from the live
    pipeline, so they must be kept in sync with it by hand.
    """
    import os

    # Label -> checkpoint path / loaded model, in the order they are reported.
    checkpoint_paths = {
        "SFT": SFT_MODEL_PATH,
        "Reward": REWARD_MODEL_PATH,
        "PPO": PPO_MODEL_PATH,
    }
    loaded_models = {
        "SFT": sft_model,
        "Reward": reward_model,
        "PPO": ppo_model,
    }

    server_info = {
        "device": str(device),
        "torch_version": torch.__version__,
        "transformers_version": transformers.__version__,
        "random_seed": 42,
        "models_loaded": {
            label: os.path.basename(path)
            for label, path in checkpoint_paths.items()
        },
    }

    generation_params = {
        "max_length": 128,
        "num_beams": 4,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.3,
        "do_sample": False,
        "temperature": "N/A (deterministic)",
    }

    preprocessing = {
        "resize": [224, 224],
        "normalize_mean": [0.485, 0.456, 0.406],
        "normalize_std": [0.229, 0.224, 0.225],
        "convert": "RGB",
    }

    return {
        "server_info": server_info,
        "model_hashes": {
            label: get_model_hash(net) for label, net in loaded_models.items()
        },
        "generation_params": generation_params,
        "preprocessing": preprocessing,
        "model_states": {
            f"{label}_eval": not net.training
            for label, net in loaded_models.items()
        },
    }
621
+
622
+
623
  @app.get("/debug_keys")
624
  def debug_keys():
625
  import os