File size: 4,312 Bytes
c33d894 d844186 c33d894 5d50d54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from typing import List

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel

from model_registry import get_model
from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
# Single application instance; every route in this module registers on it.
app = FastAPI(title="Multimodal Retrieval & Captioning API")

# Wide-open CORS: any origin, any HTTP method, any request header.
# allow_credentials is not passed, so it keeps Starlette's default (False)
# and credentialed cross-origin requests are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class InferenceRequest(BaseModel):
    """Pydantic model carrying a model name and a result count.

    NOTE(review): none of the endpoints visible in this module reference this
    model; they read ``model_name``/``top_k`` from form fields instead.
    Confirm whether it is used elsewhere before relying on or removing it.
    """

    # Presumably the same registry key the endpoints pass to get_model(); verify.
    model_name: str
    # Requested number of results; defaults to 5, matching the endpoints' Form default.
    top_k: int = 5
def _load_rgb_image(upload: UploadFile) -> Image.Image:
    """Decode an uploaded file into an RGB PIL image.

    Args:
        upload: The multipart file received by an endpoint.

    Returns:
        An RGB copy of the uploaded image.

    Raises:
        HTTPException: 400 when Pillow cannot decode the upload as an image.
    """
    try:
        # Close the opened image deterministically instead of leaving it to the
        # garbage collector. convert() loads the pixel data and returns a new
        # image, so the result stays usable after the with-block exits.
        with Image.open(upload.file) as source:
            return source.convert("RGB")
    except OSError as err:
        # Pillow reports unreadable/corrupt image data with OSError (its
        # UnidentifiedImageError is a subclass). Surface it as a client error
        # instead of letting it escape as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


# The endpoints below are plain ``def`` rather than ``async def``: image
# decoding and the model calls are synchronous/blocking, and inside an
# ``async def`` they would run on the event loop and stall every other request
# (including /health) until they finish. FastAPI runs ``def`` path operations
# in its threadpool instead.


@app.post("/caption")
def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
    """Caption an uploaded image with the model registered as ``model_name``.

    Returns ``{"caption": ...}`` holding whatever ``generate_caption`` produces.
    Responds 400 if the upload is not a decodable image.
    """
    image = _load_rgb_image(file)
    model = get_model(model_name)
    caption = model.generate_caption(image)
    return {"caption": caption}


@app.post("/search/text2img")
def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
    """Run the named model's text-to-image search for ``query``.

    ``top_k`` is forwarded unchanged; the model's result is returned as-is.
    """
    model = get_model(model_name)
    results = model.text_to_image(query, top_k)
    return results


@app.post("/search/img2text")
def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
    """Run the named model's image-to-text search for an uploaded image.

    ``top_k`` is forwarded unchanged; the model's result is returned as-is.
    Responds 400 if the upload is not a decodable image.
    """
    image = _load_rgb_image(file)
    model = get_model(model_name)
    results = model.image_to_text(image, top_k)
    return results


@app.post("/search/img2img")
def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
    """Run the named model's image-to-image search for an uploaded image.

    ``top_k`` is forwarded unchanged; the model's result is returned as-is.
    Responds 400 if the upload is not a decodable image.
    """
    image = _load_rgb_image(file)
    model = get_model(model_name)
    results = model.image_to_image(image, top_k)
    return results


@app.post("/search/text2text")
def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
    """Run the named model's text-to-text search for ``query``.

    ``top_k`` is forwarded unchanged; the model's result is returned as-is.
    """
    model = get_model(model_name)
    results = model.text_to_text(query, top_k)
    return results
@app.get("/health")
def health_check():
    """Report that the API process is up and answering requests."""
    status_payload = dict(status="healthy")
    return status_payload
# # api.py
# from fastapi import FastAPI, UploadFile, File, Form
# from fastapi.middleware.cors import CORSMiddleware
# from PIL import Image
# from typing import List
# from pydantic import BaseModel
# from models.resnet_lstm_attention.loader import load_captioning_model
# from models.resnet_lstm_attention.cap_mod_defs import EncoderCNN
# from model_registry import get_model
# from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
# app = FastAPI(title="Multimodal Retrieval & Captioning API")
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_methods=["*"],
# allow_headers=["*"],
# )
# class InferenceRequest(BaseModel):
# model_name: str
# top_k: int = 5
# #@app.post("/caption", response_model=CaptionResult)
# @app.post("/caption")
# async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
# image = Image.open(file.file).convert("RGB")
# model = get_model(model_name)
# caption = model.generate_caption(image)
# return {"caption": caption}
# #@app.post("/search/text2img", response_model=List[ImageResult])
# @app.post("/search/text2img")
# async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
# model = get_model(model_name)
# results = model.text_to_image(query, top_k)
# return results
# @app.post("/search/img2text")
# async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
# image = Image.open(file.file).convert("RGB")
# model = get_model(model_name)
# results = model.image_to_text(image, top_k)
# return results
# #@app.post("/search/img2img", response_model=List[ImageResult])
# @app.post("/search/img2img")
# async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
# image = Image.open(file.file).convert("RGB")
# model = get_model(model_name)
# results = model.image_to_image(image, top_k)
# return results
# @app.post("/search/text2text")
# async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
# model = get_model(model_name)
# results = model.text_to_text(query, top_k)
# return results
# @app.get("/health")
# def health_check():
# return {"status": "healthy"} |