# app.py
import io, os, threading
from typing import Optional, Any, Dict
from PIL import Image
import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# ── CORS ───────────────────────────────────────────────────────────────────────
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
ALLOW_CREDENTIALS = not (len(ALLOWED_ORIGINS) == 1 and ALLOWED_ORIGINS[0] == "*")
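# Note: browsers reject Access-Control-Allow-Credentials combined with a wildcard
# origin, so credentials are only enabled when an explicit origin list is configured.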
app = FastAPI(title="Image Captioning API (BLIP2 / CogVLM)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=ALLOW_CREDENTIALS,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── (Simple) Bearer auth via a secret token in the Authorization header ──────
API_TOKEN = os.environ.get("VITE_API_TOKEN")  # set in HF Space Settings → Secrets
def check_auth(auth_header: Optional[str]):
    if API_TOKEN:
        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Missing Bearer token")
        token = auth_header.split(" ", 1)[1]
        if token != API_TOKEN:
            raise HTTPException(status_code=403, detail="Invalid token")
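# Example request (a sketch; assumes the server listens on localhost:7860, the
# usual HF Spaces port, and that VITE_API_TOKEN is set):
#
#   curl -X POST http://localhost:7860/caption \
#     -H "Authorization: Bearer $VITE_API_TOKEN" \
#     -F "file=@photo.jpg" -F "model=blip2"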
# ── Model registry (thread-safe lazy loading) ─────────────────────────────────
MODELS: Dict[str, Any] = {"blip2": None, "cogvlm": None}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CACHE_DIR = os.environ.get("HF_HOME", "/app/.cache/huggingface")
# Thread-safe locks for model loading
_loading_locks = {"blip2": threading.Lock(), "cogvlm": threading.Lock()}
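# Each caption_* function below uses double-checked locking: the first (unlocked)
# check keeps the hot path cheap once a model is loaded; the second check inside
# the lock guarantees the model is loaded exactly once under concurrent requests.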
def remove_prompt_from_output(full_text: str, original_prompt: str) -> str:
    """Best-effort removal of the input prompt from the model output."""
    if not original_prompt:
        return full_text
    # Try an exact prefix match
    if full_text.startswith(original_prompt):
        return full_text[len(original_prompt):].strip()
    # Try a case-insensitive match
    if full_text.lower().startswith(original_prompt.lower()):
        return full_text[len(original_prompt):].strip()
    # Try with normalized whitespace
    prompt_normalized = " ".join(original_prompt.split())
    text_normalized = " ".join(full_text.split())
    if text_normalized.startswith(prompt_normalized):
        # Drop the prompt's words from the original text
        words_to_remove = len(original_prompt.split())
        remaining_words = full_text.split()[words_to_remove:]
        return " ".join(remaining_words).strip()
    # Fallback: return the full text unchanged
    return full_text
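# For example (hypothetical values, illustrating the exact-prefix branch):
#   remove_prompt_from_output("Describe this image. A cat on a sofa.",
#                             "Describe this image.")
#   -> "A cat on a sofa."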
def load_blip2():
    try:
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        name = os.environ.get("BLIP2_NAME", "Salesforce/blip2-opt-2.7b")
        processor = Blip2Processor.from_pretrained(name, cache_dir=CACHE_DIR)
        model = Blip2ForConditionalGeneration.from_pretrained(
            name,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            cache_dir=CACHE_DIR
        ).to(DEVICE)
        return {"name": name, "processor": processor, "model": model}
    except Exception as e:
        # Clean up on failure
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        raise HTTPException(500, detail=f"Failed to load BLIP2: {str(e)}")
def caption_blip2(image: Image.Image, prompt: Optional[str], max_new_tokens: int):
    # Thread-safe lazy loading
    if MODELS["blip2"] is None:
        with _loading_locks["blip2"]:
            if MODELS["blip2"] is None:  # Double-check inside the lock
                MODELS["blip2"] = load_blip2()
    entry = MODELS["blip2"]
    processor = entry["processor"]
    model = entry["model"]
    text = prompt or "Describe this image."
    try:
        inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # BLIP2 returns the full sequence, including the input prompt
        full_text = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        # Strip the input prompt from the result
        caption = remove_prompt_from_output(full_text, text)
        return caption if caption else full_text
    except Exception as e:
        raise HTTPException(500, detail=f"BLIP2 generation failed: {str(e)}")
def load_cogvlm():
    try:
        from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
        name = os.environ.get("COGVLM_NAME", "THUDM/cogvlm2-llama3-captioner")
        processor = AutoProcessor.from_pretrained(name, trust_remote_code=True, cache_dir=CACHE_DIR)
        tokenizer = AutoTokenizer.from_pretrained(
            name,
            trust_remote_code=True,
            use_fast=False,  # CogVLM can misbehave with the fast tokenizer
            cache_dir=CACHE_DIR
        )
        # Consistent device handling: use either device_map="auto" or .to(DEVICE), never both
        if DEVICE == "cuda" and torch.cuda.device_count() > 1:
            # Multi-GPU setup
            model = AutoModelForCausalLM.from_pretrained(
                name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            )
        else:
            # Single-GPU/CPU setup
            model = AutoModelForCausalLM.from_pretrained(
                name,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            ).to(DEVICE)
        return {"name": name, "processor": processor, "tokenizer": tokenizer, "model": model}
    except Exception as e:
        # Clean up on failure
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        raise HTTPException(500, detail=f"Failed to load CogVLM: {str(e)}")
def caption_cogvlm(image: Image.Image, prompt: Optional[str], max_new_tokens: int):
    # Thread-safe lazy loading
    if MODELS["cogvlm"] is None:
        with _loading_locks["cogvlm"]:
            if MODELS["cogvlm"] is None:  # Double-check inside the lock
                MODELS["cogvlm"] = load_cogvlm()
    entry = MODELS["cogvlm"]
    processor = entry["processor"]
    tokenizer = entry["tokenizer"]
    model = entry["model"]
    text = prompt or "Describe this image."
    try:
        # Send inputs to whichever device the model actually ended up on
        target_device = model.device if hasattr(model, 'device') else DEVICE
        inputs = processor(images=image, text=text, return_tensors="pt").to(target_device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # CogVLM may also return the full sequence, including the prompt
        full_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
        # Strip the input prompt from the result
        caption = remove_prompt_from_output(full_text, text)
        return caption if caption else full_text
    except Exception as e:
        raise HTTPException(500, detail=f"CogVLM generation failed: {str(e)}")
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
return {
"message": "Image Captioning API (BLIP2 / CogVLM)",
"endpoints": ["/health", "/caption"],
"device": DEVICE,
"models": list(MODELS.keys()),
"cuda_available": torch.cuda.is_available(),
"cuda_devices": torch.cuda.device_count() if torch.cuda.is_available() else 0
}
@app.get("/health")
def health():
return {
"status": "ok",
"device": DEVICE,
"cuda": torch.cuda.is_available(),
"loaded_models": [k for k, v in MODELS.items() if v is not None]
}
@app.get("/caption")
def caption_info():
return {
"method": "POST",
"description": "Upload image and get caption",
"parameters": {
"file": "image file (required)",
"model": "blip2 or cogvlm (default: blip2)",
"prompt": "custom prompt (optional)",
"max_new_tokens": "max tokens to generate (default: 64)"
},
"auth": "Bearer token in Authorization header (if API_TOKEN is set)"
}
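# Example client call (a sketch; the base URL, token, and "photo.jpg" are
# hypothetical placeholders):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/caption",
#       headers={"Authorization": "Bearer <token>"},
#       files={"file": open("photo.jpg", "rb")},
#       data={"model": "blip2", "max_new_tokens": 128},
#   )
#   print(r.json()["caption"])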
@app.post("/caption")
async def caption(
file: UploadFile = File(...),
model: str = Form("blip2"),
prompt: Optional[str] = Form(None),
max_new_tokens: int = Form(128),
authorization: Optional[str] = Header(None)
):
check_auth(authorization)
if model not in ("blip2", "cogvlm"):
raise HTTPException(400, detail="model must be 'blip2' or 'cogvlm'")
if max_new_tokens < 1 or max_new_tokens > 512:
raise HTTPException(400, detail="max_new_tokens must be between 1 and 512")
try:
content = await file.read()
image = Image.open(io.BytesIO(content)).convert("RGB")
except Exception as e:
raise HTTPException(400, detail=f"Invalid image file: {str(e)}")
try:
if model == "blip2":
caption_text = caption_blip2(image, prompt, max_new_tokens)
else:
caption_text = caption_cogvlm(image, prompt, max_new_tokens)
return JSONResponse({
"model": model,
"caption": caption_text,
"prompt_used": prompt or "Describe this image.",
"max_new_tokens": max_new_tokens
})
except HTTPException:
# Re-raise HTTP exceptions (uΕΎ majΓ­ sprΓ‘vnΓ½ status code)
raise
except Exception as e:
# Catch-all pro neočekÑvané chyby
raise HTTPException(500, detail=f"Caption generation failed: {str(e)}")
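# Local dev entry point (a minimal sketch, not part of the original deployment;
# assumes uvicorn is installed and uses port 7860, the HF Spaces convention).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)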