Ryanfafa commited on
Commit
dc8ba9c
·
verified ·
1 Parent(s): 83fde1f

Delete image_captioning/app.py

Browse files
Files changed (1) hide show
  1. image_captioning/app.py +0 -60
image_captioning/app.py DELETED
@@ -1,60 +0,0 @@
1
- import io
2
- from typing import List
3
-
4
- from fastapi import FastAPI, File, UploadFile
5
- from fastapi.responses import JSONResponse
6
- from PIL import Image
7
- import torch
8
- from torchvision import transforms
9
-
10
- from image_captioning.config import TrainingConfig, get_device
11
- from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
12
- from image_captioning.model import ImageCaptioningModel
13
-
14
- app = FastAPI(title="Image Captioning API (HF Space)")
15
-
16
- device = get_device()
17
- training_cfg = TrainingConfig(max_caption_length=50)
18
- tokenizer = create_tokenizer()
19
- model = ImageCaptioningModel(training_cfg=training_cfg)
20
- model.to(device)
21
- model.eval()
22
-
23
- # Load checkpoint from the repo root
24
- CHECKPOINT_PATH = "best_model.pt"
25
- state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
26
- model.load_state_dict(state_dict)
27
-
28
- preprocess = transforms.Compose(
29
- [
30
- transforms.Resize(256),
31
- transforms.CenterCrop(224),
32
- transforms.ToTensor(),
33
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
34
- ]
35
- )
36
-
37
-
38
- @app.get("/health")
39
- async def health() -> dict:
40
- return {"status": "ok"}
41
-
42
-
43
- @app.post("/caption")
44
- async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
45
- try:
46
- contents = await file.read()
47
- image = Image.open(io.BytesIO(contents)).convert("RGB")
48
- except Exception as exc:
49
- return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})
50
-
51
- tensor = preprocess(image).unsqueeze(0).to(device)
52
-
53
- with torch.no_grad():
54
- captions: List[str] = model.generate(
55
- images=tensor,
56
- max_length=50,
57
- num_beams=1, # deterministic greedy decoding
58
- )
59
-
60
- return JSONResponse({"caption": captions[0]})