import io
import logging
import threading

import torch
from clip_interrogator import Config, Interrogator
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
|
|
# The FastAPI application served by this module.
app = FastAPI()

# CORS is fully open: any origin, any method, any header, credentials allowed.
# NOTE(review): the Starlette CORS docs say wildcard origins/methods/headers
# cannot be combined with credentialed requests; depending on the Starlette
# version the request Origin may instead be reflected, letting any site make
# credentialed calls. Nothing visible here uses cookies/auth, so consider
# allow_credentials=False or an explicit origin list -- confirm intent.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
| |
# CLIP Interrogator setup. CUDA availability is probed once and reused.
_cuda_available = torch.cuda.is_available()

config = Config()
# Use the GPU when present, otherwise fall back to CPU.
config.device = 'cuda' if _cuda_available else 'cpu'
# BLIP offloading is enabled only when CUDA is unavailable.
# NOTE(review): `blip_offload` / `blip_num_beams` belong to older
# clip-interrogator releases; newer ones renamed the BLIP options to
# `caption_*`, in which case these two assignments are silently ignored.
# Confirm against the pinned clip-interrogator version.
config.blip_offload = not _cuda_available
# Presumably the CLIP batch size (lower it for less VRAM) -- confirm.
config.chunk_size = 2048
config.flavor_intermediate_count = 512
# NOTE(review): 64 beams is a very wide BLIP beam search and likely makes
# captioning slow; confirm this is intended.
config.blip_num_beams = 64

# Built once at import time (loads the models) and shared by every request.
ci = Interrogator(config)
|
|
_logger = logging.getLogger(__name__)

# `ci` is a single shared Interrogator. Inference now runs in worker threads
# (so it no longer blocks the event loop); this lock keeps at most one
# interrogation in flight, matching the previous one-at-a-time behaviour and
# avoiding concurrent use of the model / device memory.
# NOTE(review): relax only if clip_interrogator is confirmed thread-safe.
_inference_lock = threading.Lock()


def _decode_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image (blocking).

    Raises:
        OSError: the bytes are not a readable image (this includes
            ``PIL.UnidentifiedImageError``) or the image is truncated.
        PIL.Image.DecompressionBombError: the image exceeds Pillow's
            decompression-bomb pixel limit.
    """
    with Image.open(io.BytesIO(data)) as img:
        # convert() returns a new image, so closing `img` afterwards is safe.
        return img.convert('RGB')


def _run_interrogation(image: Image.Image, mode: str, best_max_flavors: int):
    """Run the shared Interrogator on *image* (blocking; call off the loop).

    ``'best'`` -> ``ci.interrogate(max_flavors=best_max_flavors)``,
    ``'classic'`` -> ``ci.interrogate_classic``; any other mode falls back
    to ``ci.interrogate_fast``.
    """
    with _inference_lock:
        if mode == 'best':
            return ci.interrogate(image, max_flavors=best_max_flavors)
        if mode == 'classic':
            return ci.interrogate_classic(image)
        return ci.interrogate_fast(image)


@app.post("/inference/")
async def interrogate_images(
    file: UploadFile = File(...),
    mode: str = Form(...),
    best_max_flavors: int = Form(...),
):
    """Generate a prompt for an uploaded image with CLIP Interrogator.

    Form fields:
        file: the uploaded image.
        mode: ``'best'``, ``'classic'``, or any other value for fast mode.
        best_max_flavors: ``max_flavors`` for ``ci.interrogate``; only used
            when ``mode == 'best'``.

    Returns:
        200 ``{"prompt_results": [prompt]}`` on success;
        400 ``{"error": ...}`` if the upload cannot be decoded as an image;
        500 ``{"error": ...}`` for any other failure (also logged).
    """
    try:
        # NOTE(review): the whole upload is read into memory with no size
        # cap; consider enforcing a maximum upload size.
        contents = await file.read()
        try:
            # Decoding and inference are blocking CPU/GPU work; run them in
            # the threadpool so other requests are not stalled.
            image = await run_in_threadpool(_decode_image, contents)
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable/oversized input is a client error, not a 500.
            return JSONResponse(content={"error": str(e)}, status_code=400)
        prompt_result = await run_in_threadpool(
            _run_interrogation, image, mode, best_max_flavors
        )
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs.
        _logger.exception("Inference failed (mode=%r)", mode)
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={"prompt_results": [prompt_result]})
|
|
|
|