Ryanfafa commited on
Commit
90ac32f
·
verified ·
1 Parent(s): 35a1f66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -20,7 +20,6 @@ model = ImageCaptioningModel(training_cfg=training_cfg)
20
  model.to(device)
21
  model.eval()
22
 
23
- # Load checkpoint from the repo root
24
  CHECKPOINT_PATH = "best_model.pt"
25
  state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
26
  model.load_state_dict(state_dict)
@@ -54,7 +53,7 @@ async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
54
  captions: List[str] = model.generate(
55
  images=tensor,
56
  max_length=50,
57
- num_beams=1, # deterministic greedy decoding
58
  )
59
 
60
  return JSONResponse({"caption": captions[0]})
 
20
  model.to(device)
21
  model.eval()
22
 
 
23
  CHECKPOINT_PATH = "best_model.pt"
24
  state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
25
  model.load_state_dict(state_dict)
 
53
  captions: List[str] = model.generate(
54
  images=tensor,
55
  max_length=50,
56
+ num_beams=1,
57
  )
58
 
59
  return JSONResponse({"caption": captions[0]})