Spaces:
Sleeping
Sleeping
import io
import os
import pickle

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image

# Import from model.py
from model import (
    Vocabulary,
    ResNetEncoder,
    DecoderLSTM,
    ImageCaptioningModel,
    generate_caption,
    transform,
    EMBED_DIM,
    HIDDEN_DIM,
)
app = FastAPI(title="Image Captioning API")

# -------------------------
# Enable CORS
# -------------------------
# Let browsers on any origin call this API with any method and any header.
# NOTE(review): FastAPI's CORS docs state credentials are not supported with
# wildcard origins/methods/headers. None of the endpoints visible in this
# module read cookies or auth headers, so allow_credentials=False is probably
# what is intended — confirm with the front-end before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Inference device: a CUDA GPU when torch can see one, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -------------------------
# Paths
# -------------------------
# Local assets are resolved relative to this source file, not the current
# working directory, so the service can be launched from anywhere.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
# Trained weights live on the Hugging Face Hub rather than in this repo;
# hf_hub_download returns the local filesystem path of the fetched file.
CHECKPOINT_PATH = hf_hub_download(repo_id="VIKRAM989/image-label", filename="best_checkpoint.pth")
# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that binds any pickled class named ``Vocabulary`` to ours.

    For that one name the module path recorded in the pickle is ignored and
    the ``Vocabulary`` imported from ``model`` is returned instead —
    presumably because vocab.pkl was written from a different module (e.g. a
    training script run as ``__main__``); confirm if the pickle is rebuilt.
    Every other global is resolved by the standard ``pickle.Unpickler``.

    NOTE: like any unpickler, this can execute arbitrary code; only use it on
    trusted files.
    """

    def find_class(self, module, name):
        # Anything other than Vocabulary takes the default lookup path.
        if name != "Vocabulary":
            return super().find_class(module, name)
        return Vocabulary
# Deserialize the vocabulary shipped next to this file.
# NOTE: unpickling can execute arbitrary code — VOCAB_PATH must remain a
# trusted, locally shipped artifact and never point at user-supplied data.
with open(VOCAB_PATH, mode="rb") as f:
    vocab = CustomUnpickler(f).load()
# Passed to DecoderLSTM when the model is built.
vocab_size = len(vocab)
# -------------------------
# Build Model
# -------------------------
# Encoder and decoder are both built with EMBED_DIM; the decoder additionally
# takes HIDDEN_DIM and the vocabulary size.
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder)
model = model.to(DEVICE)  # place the assembled model on the inference device
# -------------------------
# Load Weights
# -------------------------
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file. On torch < 2.6 the default is
# weights_only=False, which can execute code embedded in the downloaded
# checkpoint. Consider weights_only=True once it is confirmed that the
# checkpoint holds only tensors and plain containers.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])

# Evaluation mode (affects layers such as dropout / batch-norm, if present).
model.eval()
print("✅ Model Loaded Successfully")
# -------------------------
# Health Check
# -------------------------
@app.get("/")
def root():
    """Report that the API process is up.

    Registered on ``GET /``. Without the route decorator this handler was
    never attached to ``app`` and therefore unreachable.

    Returns:
        A static JSON message; it does not exercise the model.
    """
    return {"message": "Image Captioning API Running"}
# -------------------------
# Caption Endpoint
# -------------------------
# NOTE(review): the route path was not present in the source; "/caption" is a
# reconstruction — confirm it matches what clients call.
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        file: The uploaded image (multipart form field named ``file``).

    Returns:
        ``{"caption": ...}`` holding whatever ``generate_caption`` returns
        (presumably a string — confirm in model.py).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # Image.open is lazy; convert() forces the full decode, so malformed
        # or truncated data fails here instead of deep in the model pipeline.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (not an image / empty upload) and truncated-
        # file errors are OSError subclasses. This is a client error, not a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    image = transform(image)
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        caption = generate_caption(model, image, vocab)
    return {"caption": caption}
if __name__ == "__main__":
    # Pass the app object, not the "main:app" import string. Uvicorn resolves
    # an import string by importing the module again under the name "main".
    # That re-runs every module-level statement here, including the
    # checkpoint download and model load. It also fails outright if this file
    # is not named main.py. (An import string is only required for
    # reload=True or workers > 1, neither of which is used.)
    uvicorn.run(app, host="0.0.0.0", port=7860)