Ryanfafa commited on
Commit
c943da6
·
verified ·
1 Parent(s): 5b01047

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -76
app.py DELETED
@@ -1,76 +0,0 @@
1
- import io
2
- from typing import List
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi import FastAPI, File, UploadFile
5
- from fastapi.responses import JSONResponse
6
- from PIL import Image
7
- import torch
8
- from torchvision import transforms
9
-
10
- from image_captioning.config import TrainingConfig, get_device
11
- from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
12
- from image_captioning.model import ImageCaptioningModel
13
-
14
# FastAPI application with permissive CORS so a browser front-end
# (e.g. a GitHub Pages site) can call this API cross-origin.
app = FastAPI(title="Image Captioning API (HF Space)")

# TODO(security): replace the wildcard origin with the exact front-end URL.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
22
-
23
@app.get("/")
async def root():
    """Landing route: confirm the service is up and point callers at the docs."""
    greeting = {
        "message": "Image Captioning API is running. Use /docs for the UI or POST /caption for captions."
    }
    return greeting
26
-
27
# NOTE(review): this stub was registered with @app.post("/caption") — the same
# path as the real handler `caption_image` defined later in this module.
# FastAPI resolves requests against routes in registration order, so the stub
# shadowed the real endpoint and every request received the placeholder text.
# The duplicate route registration is removed; the function itself is kept
# (undecorated) so any direct callers keep working.
async def get_caption(file: UploadFile = File(...)):
    """Deprecated placeholder handler; superseded by `caption_image`."""
    return {"caption": "The generated caption text here"}
32
-
33
# --- One-time model initialisation (runs at import time) --------------------
device = get_device()
training_cfg = TrainingConfig(max_caption_length=50)
tokenizer = create_tokenizer()

model = ImageCaptioningModel(training_cfg=training_cfg)

# Load the trained weights BEFORE moving the module to the device and
# switching to eval mode, so initialisation reads as one ordered sequence
# (the original called .to()/.eval() first, which worked but was fragile).
CHECKPOINT_PATH = "best_model.pt"
# SECURITY: torch.load unpickles arbitrary objects. CHECKPOINT_PATH must be a
# trusted local file — never point it at user-supplied input. On torch >= 2.0
# consider torch.load(..., weights_only=True) for a plain state_dict.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
43
-
44
# Standard ImageNet-style inference preprocessing: resize the shorter side to
# 256 px, centre-crop to 224x224, convert to a tensor, then normalise with
# the ImageNet channel statistics used at training time.
_inference_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
preprocess = transforms.Compose(_inference_steps)
52
-
53
-
54
@app.get("/health")
async def health() -> dict:
    """Liveness probe for the hosting platform / load balancer."""
    status: dict = {"status": "ok"}
    return status
57
-
58
-
59
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
    """Generate a caption for an uploaded image.

    Returns HTTP 400 with an ``{"error": ...}`` payload when the upload
    cannot be decoded as an image, HTTP 500 if the model unexpectedly
    produces no caption, and otherwise ``{"caption": <str>}``.
    """
    try:
        contents = await file.read()
        # PIL raises several unrelated exception types for bad bytes, so the
        # broad catch is deliberate at this request boundary.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})

    # Build a (1, C, H, W) batch on the model's device.
    tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        captions: List[str] = model.generate(
            images=tensor,
            max_length=50,
            num_beams=1,
        )

    # Defensive guard: the original indexed captions[0] unconditionally, so an
    # empty result crashed with an IndexError (opaque 500 to the client).
    if not captions:
        return JSONResponse(
            status_code=500,
            content={"error": "Caption generation produced no output."},
        )
    return JSONResponse({"caption": captions[0]})