Ryanfafa commited on
Commit
37ad06e
·
verified ·
1 Parent(s): 8210416

Create app.py

Browse files
Files changed (1) hide show
  1. image_captioning/app.py +60 -0
image_captioning/app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import List
3
+
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from fastapi.responses import JSONResponse
6
+ from PIL import Image
7
+ import torch
8
+ from torchvision import transforms
9
+
10
+ from image_captioning.config import TrainingConfig, get_device
11
+ from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
12
+ from image_captioning.model import ImageCaptioningModel
13
+
14
+ app = FastAPI(title="Image Captioning API (HF Space)")
15
+
16
+ device = get_device()
17
+ training_cfg = TrainingConfig(max_caption_length=50)
18
+ tokenizer = create_tokenizer()
19
+ model = ImageCaptioningModel(training_cfg=training_cfg)
20
+ model.to(device)
21
+ model.eval()
22
+
23
+ # Load checkpoint from the repo root
24
+ CHECKPOINT_PATH = "best_model.pt"
25
+ state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
26
+ model.load_state_dict(state_dict)
27
+
28
+ preprocess = transforms.Compose(
29
+ [
30
+ transforms.Resize(256),
31
+ transforms.CenterCrop(224),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
34
+ ]
35
+ )
36
+
37
+
38
+ @app.get("/health")
39
+ async def health() -> dict:
40
+ return {"status": "ok"}
41
+
42
+
43
+ @app.post("/caption")
44
+ async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
45
+ try:
46
+ contents = await file.read()
47
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
48
+ except Exception as exc:
49
+ return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})
50
+
51
+ tensor = preprocess(image).unsqueeze(0).to(device)
52
+
53
+ with torch.no_grad():
54
+ captions: List[str] = model.generate(
55
+ images=tensor,
56
+ max_length=50,
57
+ num_beams=1, # deterministic greedy decoding
58
+ )
59
+
60
+ return JSONResponse({"caption": captions[0]})