Ryanfafa committed on
Commit
040b1a3
·
verified ·
1 Parent(s): 3c4d3f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -32
app.py CHANGED
@@ -11,66 +11,65 @@ from image_captioning.config import TrainingConfig, get_device
11
  from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
12
  from image_captioning.model import ImageCaptioningModel
13
 
 
14
  app = FastAPI(title="Image Captioning API (HF Space)")
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
- allow_origins=["*"], # Allows all domains. For security, replace with your GitHub Pages URL later.
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
- @app.get("/")
24
- async def root():
25
- return {"message": "Image Captioning API is running. Use /docs for the UI or POST /caption for captions."}
26
-
27
- @app.post("/caption")
28
- async def get_caption(file: UploadFile = File(...)):
29
- # Your existing logic to process the image and generate a caption
30
- # result = model.predict(image)
31
- return {"caption": "The generated caption text here"}
32
-
33
  device = get_device()
34
  training_cfg = TrainingConfig(max_caption_length=50)
35
  tokenizer = create_tokenizer()
36
  model = ImageCaptioningModel(training_cfg=training_cfg)
37
- model.to(device)
38
- model.eval()
39
 
 
40
  CHECKPOINT_PATH = "best_model.pt"
41
  state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
42
  model.load_state_dict(state_dict)
 
 
43
 
44
- preprocess = transforms.Compose(
45
- [
46
- transforms.Resize(256),
47
- transforms.CenterCrop(224),
48
- transforms.ToTensor(),
49
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
50
- ]
51
- )
52
 
 
 
 
 
53
 
54
  @app.get("/health")
55
  async def health() -> dict:
56
  return {"status": "ok"}
57
 
58
-
59
  @app.post("/caption")
60
  async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
61
  try:
62
  contents = await file.read()
63
  image = Image.open(io.BytesIO(contents)).convert("RGB")
64
- except Exception as exc:
65
- return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})
 
66
 
67
- tensor = preprocess(image).unsqueeze(0).to(device)
 
 
 
 
 
 
68
 
69
- with torch.no_grad():
70
- captions: List[str] = model.generate(
71
- images=tensor,
72
- max_length=50,
73
- num_beams=1,
74
- )
75
 
76
- return JSONResponse({"caption": captions[0]})
 
 
11
  from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
12
  from image_captioning.model import ImageCaptioningModel
13
 
14
+ # 1. Initialize App and CORS
15
  app = FastAPI(title="Image Captioning API (HF Space)")
16
+
17
  app.add_middleware(
18
  CORSMiddleware,
19
+ allow_origins=["*"],
20
  allow_credentials=True,
21
  allow_methods=["*"],
22
  allow_headers=["*"],
23
  )
24
 
25
+ # 2. Load Model & Assets (Global Scope)
 
 
 
 
 
 
 
 
 
26
  device = get_device()
27
  training_cfg = TrainingConfig(max_caption_length=50)
28
  tokenizer = create_tokenizer()
29
  model = ImageCaptioningModel(training_cfg=training_cfg)
 
 
30
 
31
+ # Load weights
32
  CHECKPOINT_PATH = "best_model.pt"
33
  state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
34
  model.load_state_dict(state_dict)
35
+ model.to(device)
36
+ model.eval()
37
 
38
+ # 3. Preprocessing Pipeline
39
+ preprocess = transforms.Compose([
40
+ transforms.Resize(256),
41
+ transforms.CenterCrop(224),
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
44
+ ])
 
45
 
46
+ # 4. API Routes
47
+ @app.get("/")
48
+ async def root():
49
+ return {"message": "API is online. Go to /docs for testing."}
50
 
51
  @app.get("/health")
52
  async def health() -> dict:
53
  return {"status": "ok"}
54
 
 
55
  @app.post("/caption")
56
  async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
57
  try:
58
  contents = await file.read()
59
  image = Image.open(io.BytesIO(contents)).convert("RGB")
60
+
61
+ # Preprocess and Move to Device
62
+ tensor = preprocess(image).unsqueeze(0).to(device)
63
 
64
+ # Inference
65
+ with torch.no_grad():
66
+ captions: List[str] = model.generate(
67
+ images=tensor,
68
+ max_length=50,
69
+ num_beams=1,
70
+ )
71
 
72
+ return JSONResponse({"caption": captions[0]})
 
 
 
 
 
73
 
74
+ except Exception as exc:
75
+ return JSONResponse(status_code=400, content={"error": f"Internal Error: {exc}"})