skodan's picture
fixing incorrect references
5d50d54
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from typing import List
from pydantic import BaseModel
from model_registry import get_model
from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
app = FastAPI(title="Multimodal Retrieval & Captioning API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
class InferenceRequest(BaseModel):
model_name: str
top_k: int = 5
@app.post("/caption")
async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
image = Image.open(file.file).convert("RGB")
model = get_model(model_name)
caption = model.generate_caption(image)
return {"caption": caption}
@app.post("/search/text2img")
async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
model = get_model(model_name)
results = model.text_to_image(query, top_k)
return results
@app.post("/search/img2text")
async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
image = Image.open(file.file).convert("RGB")
model = get_model(model_name)
results = model.image_to_text(image, top_k)
return results
@app.post("/search/img2img")
async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
image = Image.open(file.file).convert("RGB")
model = get_model(model_name)
results = model.image_to_image(image, top_k)
return results
@app.post("/search/text2text")
async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
model = get_model(model_name)
results = model.text_to_text(query, top_k)
return results
@app.get("/health")
def health_check():
return {"status": "healthy"}
# # api.py
# from fastapi import FastAPI, UploadFile, File, Form
# from fastapi.middleware.cors import CORSMiddleware
# from PIL import Image
# from typing import List
# from pydantic import BaseModel
# from models.resnet_lstm_attention.loader import load_captioning_model
# from models.resnet_lstm_attention.cap_mod_defs import EncoderCNN
# from model_registry import get_model
# from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
# app = FastAPI(title="Multimodal Retrieval & Captioning API")
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_methods=["*"],
# allow_headers=["*"],
# )
# class InferenceRequest(BaseModel):
# model_name: str
# top_k: int = 5
# #@app.post("/caption", response_model=CaptionResult)
# @app.post("/caption")
# async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
# image = Image.open(file.file).convert("RGB")
# model = get_model(model_name)
# caption = model.generate_caption(image)
# return {"caption": caption}
# #@app.post("/search/text2img", response_model=List[ImageResult])
# @app.post("/search/text2img")
# async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
# model = get_model(model_name)
# results = model.text_to_image(query, top_k)
# return results
# @app.post("/search/img2text")
# async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
# image = Image.open(file.file).convert("RGB")
# model = get_model(model_name)
# results = model.image_to_text(image, top_k)
# return results
# #@app.post("/search/img2img", response_model=List[ImageResult])
# @app.post("/search/img2img")
# async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
# image = Image.open(file.file).convert("RGB")
# model = get_model(model_name)
# results = model.image_to_image(image, top_k)
# return results
# @app.post("/search/text2text")
# async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
# model = get_model(model_name)
# results = model.text_to_text(query, top_k)
# return results
# @app.get("/health")
# def health_check():
# return {"status": "healthy"}