# Cintd / main.py
# Luisgust — "Create main.py" (commit 2269978, verified)
import asyncio
import io
import logging
import threading

import torch
from clip_interrogator import Config, Interrogator
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
app = FastAPI()

# CORS is wide open: every origin, method and header is accepted and
# credentialed requests are allowed. Narrow these before exposing the
# service publicly (NOTE(review): wildcard origins combined with
# credentials is rarely what a production deployment wants).
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Set up a single module-level CLIP Interrogator that every request shares.
# NOTE(review): constructing Interrogator presumably loads (and may download)
# model weights, so this runs once at import time — confirm startup cost.

# Query CUDA once so the device choice and the offload flag always agree.
_cuda_available = torch.cuda.is_available()

config = Config()
config.device = 'cuda' if _cuda_available else 'cpu'
# Enable BLIP offloading only when no GPU is available.
# NOTE(review): the blip_* option names match older clip-interrogator
# releases; newer releases appear to have renamed them (e.g. caption_offload),
# in which case these assignments would be silently ignored — confirm against
# the installed version.
config.blip_offload = not _cuda_available
# Presumably the CLIP batch size (smaller lowers memory use) — TODO confirm.
config.chunk_size = 2048
# Presumably how many candidate flavors are kept before the final search;
# lowered from the library default to speed up 'best' mode — TODO confirm.
config.flavor_intermediate_count = 512
# Beam-search width for BLIP captioning; a large value like this is slow.
config.blip_num_beams = 64

ci = Interrogator(config)
# Inference runs in worker threads (see interrogate_images) against the single
# shared `ci` object. This lock keeps model access one-request-at-a-time, the
# same effective behaviour the original blocking coroutine had, but without
# stalling the event loop for every other request while a model runs.
_INFERENCE_LOCK = threading.Lock()
_logger = logging.getLogger(__name__)


class _InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded as an image."""


def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        _InvalidImageError: if Pillow cannot read or convert the data. The
            message is the original Pillow error text.
    """
    try:
        return Image.open(io.BytesIO(contents)).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise _InvalidImageError(str(e)) from e


def _run_interrogation(image: Image.Image, mode: str, best_max_flavors: int):
    """Return the interrogator prompt for *image* in the requested *mode*.

    'best' uses ``ci.interrogate`` with *best_max_flavors*, 'classic' uses
    ``ci.interrogate_classic``, and any other mode falls back to
    ``ci.interrogate_fast``. Blocking: call from a worker thread, never
    directly on the event loop.
    """
    with _INFERENCE_LOCK:
        if mode == 'best':
            return ci.interrogate(image, max_flavors=best_max_flavors)
        if mode == 'classic':
            return ci.interrogate_classic(image)
        return ci.interrogate_fast(image)


@app.post("/inference/")
async def interrogate_images(file: UploadFile = File(...), mode: str = Form(...), best_max_flavors: int = Form(...)):
    """Generate a prompt describing an uploaded image.

    Form fields:
        file: the image upload.
        mode: 'best' or 'classic'; any other value selects fast mode.
        best_max_flavors: ``max_flavors`` passed to the interrogator in
            'best' mode. The form requires it even though other modes ignore it.

    Returns:
        200 ``{"prompt_results": [prompt]}`` on success,
        400 ``{"error": message}`` if the upload is not a decodable image,
        500 ``{"error": message}`` if anything else fails.
    """
    loop = asyncio.get_running_loop()
    try:
        contents = await file.read()
        # Decoding and inference are CPU/GPU-bound; run them off the event
        # loop so other requests keep being served meanwhile.
        image = await loop.run_in_executor(None, _decode_image, contents)
        prompt_result = await loop.run_in_executor(
            None, _run_interrogation, image, mode, best_max_flavors
        )
    except _InvalidImageError as e:
        # A bad upload is the client's fault, not a server error.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the logs while
        # returning the same error body shape as before.
        _logger.exception("Image interrogation failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content={"prompt_results": [prompt_result]})