Ryanfafa committed on
Commit
3c4d3f7
·
verified ·
1 Parent(s): c943da6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import List
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from fastapi.responses import JSONResponse
6
+ from PIL import Image
7
+ import torch
8
+ from torchvision import transforms
9
+
10
+ from image_captioning.config import TrainingConfig, get_device
11
+ from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
12
+ from image_captioning.model import ImageCaptioningModel
13
+
14
# FastAPI application object serving the image-captioning model on a HF Space.
app = FastAPI(title="Image Captioning API (HF Space)")

# Permissive CORS so a browser front-end hosted elsewhere can call this API.
# NOTE(review): per the CORS spec, browsers reject credentialed responses when
# `Access-Control-Allow-Origin` is the wildcard "*" — if cookies/auth are ever
# needed, replace the wildcard with the concrete front-end origin. Confirm
# whether `allow_credentials=True` is actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all domains. For security, replace with your GitHub Pages URL later.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
22
+
23
+ @app.get("/")
24
+ async def root():
25
+ return {"message": "Image Captioning API is running. Use /docs for the UI or POST /caption for captions."}
26
+
27
+ @app.post("/caption")
28
+ async def get_caption(file: UploadFile = File(...)):
29
+ # Your existing logic to process the image and generate a caption
30
+ # result = model.predict(image)
31
+ return {"caption": "The generated caption text here"}
32
+
33
# --- One-time model initialization at import/startup ---------------------
# Select the compute device (helper from image_captioning.config).
device = get_device()
# Generation/training configuration; max_caption_length bounds decoded captions.
training_cfg = TrainingConfig(max_caption_length=50)
# Tokenizer matching the training vocabulary.
# NOTE(review): `tokenizer` is not referenced elsewhere in this file — presumably
# used internally by model.generate; confirm, otherwise it is dead state here.
tokenizer = create_tokenizer()
model = ImageCaptioningModel(training_cfg=training_cfg)
model.to(device)
model.eval()  # inference mode: disables dropout/batch-norm updates

# Restore trained weights. map_location keeps this working on CPU-only hosts.
# NOTE(review): torch.load on an untrusted checkpoint unpickles arbitrary
# objects — fine for a bundled file, but never point this at user input.
CHECKPOINT_PATH = "best_model.pt"
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
43
+
44
# Standard ImageNet-style inference preprocessing: shrink the short side to
# 256 px, take the central 224x224 crop, convert to a float tensor, and
# normalize with the same mean/std used at training time.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
52
+
53
+
54
+ @app.get("/health")
55
+ async def health() -> dict:
56
+ return {"status": "ok"}
57
+
58
+
59
+ @app.post("/caption")
60
+ async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
61
+ try:
62
+ contents = await file.read()
63
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
64
+ except Exception as exc:
65
+ return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})
66
+
67
+ tensor = preprocess(image).unsqueeze(0).to(device)
68
+
69
+ with torch.no_grad():
70
+ captions: List[str] = model.generate(
71
+ images=tensor,
72
+ max_length=50,
73
+ num_beams=1,
74
+ )
75
+
76
+ return JSONResponse({"caption": captions[0]})