AE-Shree commited on
Commit Β·
1537418
1
Parent(s): c934b38
Deploy BioStack RLHF Medical Demo
Browse files
server.py
CHANGED
|
@@ -337,6 +337,38 @@ def preprocess(file_bytes: bytes) -> torch.Tensor:
|
|
| 337 |
return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
|
| 338 |
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 341 |
# REWARD FEEDBACK GENERATOR
|
| 342 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -390,8 +422,72 @@ def health():
|
|
| 390 |
async def sft_inference(file: UploadFile = File(...)):
|
| 391 |
try:
|
| 392 |
tensor = preprocess(await file.read())
|
| 393 |
-
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
return {"report": report[:81]}
|
| 396 |
except Exception as e:
|
| 397 |
traceback.print_exc()
|
|
@@ -474,6 +570,56 @@ async def ppo_inference(file: UploadFile = File(...)):
|
|
| 474 |
# DIAGNOSTIC ENDPOINT β call GET /debug_keys to verify key names in your files
|
| 475 |
# e.g. curl http://localhost:8000/debug_keys
|
| 476 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
@app.get("/debug_keys")
|
| 478 |
def debug_keys():
|
| 479 |
import os
|
|
|
|
| 337 |
return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
|
| 338 |
|
| 339 |
|
| 340 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 341 |
+
# DEBUGGING TOOLS - Compare with Jupyter notebook results
|
| 342 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 343 |
+
import hashlib
|
| 344 |
+
|
| 345 |
+
def get_model_hash(model):
|
| 346 |
+
"""Get hash of model state dict for comparison"""
|
| 347 |
+
model_str = str(model.state_dict())
|
| 348 |
+
return hashlib.md5(model_str.encode()).hexdigest()
|
| 349 |
+
|
| 350 |
+
def log_inference_details(model_name, image_tensor, generated_ids, decoded_report):
|
| 351 |
+
"""Detailed logging for debugging inference differences"""
|
| 352 |
+
print(f"\n{'='*50}")
|
| 353 |
+
print(f" {model_name} INFERENCE DEBUG")
|
| 354 |
+
print(f"{'='*50}")
|
| 355 |
+
print(f"Model hash: {get_model_hash(globals()[f'{model_name.lower()}_model'])}")
|
| 356 |
+
print(f"Image tensor shape: {image_tensor.shape}")
|
| 357 |
+
print(f"Image tensor mean: {image_tensor.mean():.6f}")
|
| 358 |
+
print(f"Image tensor std: {image_tensor.std():.6f}")
|
| 359 |
+
print(f"Model in eval mode: {not globals()[f'{model_name.lower()}_model'].training}")
|
| 360 |
+
print(f"Generated IDs: {generated_ids}")
|
| 361 |
+
print(f"Generated IDs shape: {generated_ids.shape}")
|
| 362 |
+
print(f"Decoded report: '{decoded_report}'")
|
| 363 |
+
print(f"Report length: {len(decoded_report)} chars")
|
| 364 |
+
print(f"{'='*50}\n")
|
| 365 |
+
|
| 366 |
+
# Set consistent random seeds for reproducible results
|
| 367 |
+
torch.manual_seed(42)
|
| 368 |
+
torch.cuda.manual_seed_all(42)
|
| 369 |
+
print(" Random seeds set to 42 for reproducible results")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 373 |
# REWARD FEEDBACK GENERATOR
|
| 374 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 422 |
async def sft_inference(file: UploadFile = File(...)):
|
| 423 |
try:
|
| 424 |
tensor = preprocess(await file.read())
|
| 425 |
+
|
| 426 |
+
# Enhanced debugging - capture generation details
|
| 427 |
+
print(f"\nπ [SFT] DETAILED INFERENCE ANALYSIS")
|
| 428 |
+
print(f"{'='*60}")
|
| 429 |
+
print(f"Model checkpoint: {SFT_MODEL_PATH}")
|
| 430 |
+
print(f"Image tensor shape: {tensor.shape}")
|
| 431 |
+
print(f"Image tensor device: {tensor.device}")
|
| 432 |
+
print(f"Image tensor mean: {tensor.mean():.6f}")
|
| 433 |
+
print(f"Image tensor std: {tensor.std():.6f}")
|
| 434 |
+
print(f"Model in eval mode: {not sft_model.training}")
|
| 435 |
+
print(f"Using torch.no_grad: True")
|
| 436 |
+
|
| 437 |
+
# Get raw generation output before decoding
|
| 438 |
+
with torch.no_grad():
|
| 439 |
+
img_features = sft_model.img_encoder(tensor)
|
| 440 |
+
img_emb = sft_model.img_proj(img_features).unsqueeze(1)
|
| 441 |
+
batch_size = tensor.size(0)
|
| 442 |
+
img_attn = torch.ones(batch_size, 1, device=tensor.device)
|
| 443 |
+
|
| 444 |
+
encoder_outputs = sft_model.txt_model.encoder(
|
| 445 |
+
inputs_embeds=img_emb,
|
| 446 |
+
attention_mask=img_attn
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
# Log generation parameters
|
| 450 |
+
print(f"Generation parameters:")
|
| 451 |
+
print(f" - max_length: 128")
|
| 452 |
+
print(f" - num_beams: 4")
|
| 453 |
+
print(f" - early_stopping: True")
|
| 454 |
+
print(f" - no_repeat_ngram_size: 3")
|
| 455 |
+
print(f" - repetition_penalty: 1.3")
|
| 456 |
+
print(f" - do_sample: False")
|
| 457 |
+
print(f" - temperature: N/A (deterministic)")
|
| 458 |
+
|
| 459 |
+
generated = sft_model.txt_model.generate(
|
| 460 |
+
encoder_outputs=encoder_outputs,
|
| 461 |
+
attention_mask=img_attn,
|
| 462 |
+
max_length=128,
|
| 463 |
+
num_beams=4,
|
| 464 |
+
early_stopping=True,
|
| 465 |
+
no_repeat_ngram_size=3,
|
| 466 |
+
repetition_penalty=1.3,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
print(f"Raw generated IDs: {generated}")
|
| 470 |
+
print(f"Generated IDs shape: {generated.shape}")
|
| 471 |
+
|
| 472 |
+
# Decode with same parameters as notebook
|
| 473 |
+
reports = tokenizer.batch_decode(generated, skip_special_tokens=True)
|
| 474 |
+
|
| 475 |
+
# Apply same post-processing
|
| 476 |
+
cleaned_reports = []
|
| 477 |
+
for r in reports:
|
| 478 |
+
if r.lower().startswith("projection:"):
|
| 479 |
+
parts = r.split(".", 1)
|
| 480 |
+
r = parts[1].strip() if len(parts) > 1 else r
|
| 481 |
+
cleaned_reports.append(r)
|
| 482 |
+
|
| 483 |
+
report = cleaned_reports[0]
|
| 484 |
+
|
| 485 |
+
print(f"Decoded report: '{report}'")
|
| 486 |
+
print(f"Report length: {len(report)} chars")
|
| 487 |
+
print(f"Model hash: {get_model_hash(sft_model)}")
|
| 488 |
+
print(f"{'='*60}\n")
|
| 489 |
+
|
| 490 |
+
print(f"[SFT] Final Generated: {report}")
|
| 491 |
return {"report": report[:81]}
|
| 492 |
except Exception as e:
|
| 493 |
traceback.print_exc()
|
|
|
|
| 570 |
# DIAGNOSTIC ENDPOINT β call GET /debug_keys to verify key names in your files
|
| 571 |
# e.g. curl http://localhost:8000/debug_keys
|
| 572 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 573 |
+
@app.get("/debug_compare")
|
| 574 |
+
def debug_compare():
|
| 575 |
+
"""
|
| 576 |
+
Special endpoint to debug inference differences.
|
| 577 |
+
Returns detailed comparison data for troubleshooting.
|
| 578 |
+
"""
|
| 579 |
+
import os
|
| 580 |
+
|
| 581 |
+
comparison_data = {
|
| 582 |
+
"server_info": {
|
| 583 |
+
"device": str(device),
|
| 584 |
+
"torch_version": torch.__version__,
|
| 585 |
+
"transformers_version": transformers.__version__,
|
| 586 |
+
"random_seed": 42,
|
| 587 |
+
"models_loaded": {
|
| 588 |
+
"SFT": os.path.basename(SFT_MODEL_PATH),
|
| 589 |
+
"Reward": os.path.basename(REWARD_MODEL_PATH),
|
| 590 |
+
"PPO": os.path.basename(PPO_MODEL_PATH)
|
| 591 |
+
}
|
| 592 |
+
},
|
| 593 |
+
"model_hashes": {
|
| 594 |
+
"SFT": get_model_hash(sft_model),
|
| 595 |
+
"Reward": get_model_hash(reward_model),
|
| 596 |
+
"PPO": get_model_hash(ppo_model)
|
| 597 |
+
},
|
| 598 |
+
"generation_params": {
|
| 599 |
+
"max_length": 128,
|
| 600 |
+
"num_beams": 4,
|
| 601 |
+
"early_stopping": True,
|
| 602 |
+
"no_repeat_ngram_size": 3,
|
| 603 |
+
"repetition_penalty": 1.3,
|
| 604 |
+
"do_sample": False,
|
| 605 |
+
"temperature": "N/A (deterministic)"
|
| 606 |
+
},
|
| 607 |
+
"preprocessing": {
|
| 608 |
+
"resize": [224, 224],
|
| 609 |
+
"normalize_mean": [0.485, 0.456, 0.406],
|
| 610 |
+
"normalize_std": [0.229, 0.224, 0.225],
|
| 611 |
+
"convert": "RGB"
|
| 612 |
+
},
|
| 613 |
+
"model_states": {
|
| 614 |
+
"SFT_eval": not sft_model.training,
|
| 615 |
+
"Reward_eval": not reward_model.training,
|
| 616 |
+
"PPO_eval": not ppo_model.training
|
| 617 |
+
}
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
return comparison_data
|
| 621 |
+
|
| 622 |
+
|
| 623 |
@app.get("/debug_keys")
|
| 624 |
def debug_keys():
|
| 625 |
import os
|