Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| from typing import List | |
| from pydantic import BaseModel | |
| from model_registry import get_model | |
| from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery | |
| app = FastAPI(title="Multimodal Retrieval & Captioning API") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| class InferenceRequest(BaseModel): | |
| model_name: str | |
| top_k: int = 5 | |
| async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)): | |
| image = Image.open(file.file).convert("RGB") | |
| model = get_model(model_name) | |
| caption = model.generate_caption(image) | |
| return {"caption": caption} | |
| async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)): | |
| model = get_model(model_name) | |
| results = model.text_to_image(query, top_k) | |
| return results | |
| async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)): | |
| image = Image.open(file.file).convert("RGB") | |
| model = get_model(model_name) | |
| results = model.image_to_text(image, top_k) | |
| return results | |
| async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)): | |
| image = Image.open(file.file).convert("RGB") | |
| model = get_model(model_name) | |
| results = model.image_to_image(image, top_k) | |
| return results | |
| async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)): | |
| model = get_model(model_name) | |
| results = model.text_to_text(query, top_k) | |
| return results | |
| def health_check(): | |
| return {"status": "healthy"} | |
| # # api.py | |
| # from fastapi import FastAPI, UploadFile, File, Form | |
| # from fastapi.middleware.cors import CORSMiddleware | |
| # from PIL import Image | |
| # from typing import List | |
| # from pydantic import BaseModel | |
| # from models.resnet_lstm_attention.loader import load_captioning_model | |
| # from models.resnet_lstm_attention.cap_mod_defs import EncoderCNN | |
| # from model_registry import get_model | |
| # from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery | |
| # app = FastAPI(title="Multimodal Retrieval & Captioning API") | |
| # app.add_middleware( | |
| # CORSMiddleware, | |
| # allow_origins=["*"], | |
| # allow_methods=["*"], | |
| # allow_headers=["*"], | |
| # ) | |
| # class InferenceRequest(BaseModel): | |
| # model_name: str | |
| # top_k: int = 5 | |
| # #@app.post("/caption", response_model=CaptionResult) | |
| # @app.post("/caption") | |
| # async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)): | |
| # image = Image.open(file.file).convert("RGB") | |
| # model = get_model(model_name) | |
| # caption = model.generate_caption(image) | |
| # return {"caption": caption} | |
| # #@app.post("/search/text2img", response_model=List[ImageResult]) | |
| # @app.post("/search/text2img") | |
| # async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)): | |
| # model = get_model(model_name) | |
| # results = model.text_to_image(query, top_k) | |
| # return results | |
| # @app.post("/search/img2text") | |
| # async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)): | |
| # image = Image.open(file.file).convert("RGB") | |
| # model = get_model(model_name) | |
| # results = model.image_to_text(image, top_k) | |
| # return results | |
| # #@app.post("/search/img2img", response_model=List[ImageResult]) | |
| # @app.post("/search/img2img") | |
| # async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)): | |
| # image = Image.open(file.file).convert("RGB") | |
| # model = get_model(model_name) | |
| # results = model.image_to_image(image, top_k) | |
| # return results | |
| # @app.post("/search/text2text") | |
| # async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)): | |
| # model = get_model(model_name) | |
| # results = model.text_to_text(query, top_k) | |
| # return results | |
| # @app.get("/health") | |
| # def health_check(): | |
| # return {"status": "healthy"} |