import io
import os
import threading
from typing import Any, Dict, Optional

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

# Comma-separated origin list; entries are stripped so "a.com, b.com" works too.
ALLOWED_ORIGINS = [o.strip() for o in os.environ.get("ALLOWED_ORIGINS", "*").split(",")]
# Browsers reject "Access-Control-Allow-Credentials: true" combined with a
# wildcard origin, so credentials are only enabled for an explicit origin list.
ALLOW_CREDENTIALS = not (len(ALLOWED_ORIGINS) == 1 and ALLOWED_ORIGINS[0] == "*")

app = FastAPI(title="Image Captioning API (BLIP2 / CogVLM)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=ALLOW_CREDENTIALS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Optional bearer-token auth; the VITE_ prefix lets the same environment
# variable be shared with a Vite frontend build.
API_TOKEN = os.environ.get("VITE_API_TOKEN")


def check_auth(auth_header: Optional[str]):
    """Validate the Authorization header whenever API_TOKEN is configured."""
    if API_TOKEN:
        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Missing Bearer token")
        token = auth_header.split(" ", 1)[1]
        if token != API_TOKEN:
            raise HTTPException(status_code=403, detail="Invalid token")
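# For example (hypothetical token value), a client request must carry the header
#   Authorization: Bearer my-secret-token
# check_auth("Bearer my-secret-token") then passes silently; a missing or
# malformed header raises 401, and a wrong token raises 403.
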
# Lazily populated model registry: each entry stays None until first use.
MODELS: Dict[str, Any] = {"blip2": None, "cogvlm": None}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CACHE_DIR = os.environ.get("HF_HOME", "/app/.cache/huggingface")

# One lock per model so concurrent first requests cannot load the same model twice.
_loading_locks = {"blip2": threading.Lock(), "cogvlm": threading.Lock()}


def remove_prompt_from_output(full_text: str, original_prompt: str) -> str:
    """Strip the prompt from the model output, trying progressively looser matches."""
    if not original_prompt:
        return full_text

    # Exact prefix match.
    if full_text.startswith(original_prompt):
        return full_text[len(original_prompt):].strip()

    # Case-insensitive prefix match.
    if full_text.lower().startswith(original_prompt.lower()):
        return full_text[len(original_prompt):].strip()

    # Whitespace-normalized match: compare with collapsed whitespace, then drop
    # as many leading words from the output as the prompt contains.
    prompt_normalized = " ".join(original_prompt.split())
    text_normalized = " ".join(full_text.split())
    if text_normalized.startswith(prompt_normalized):
        words_to_remove = len(original_prompt.split())
        remaining_words = full_text.split()[words_to_remove:]
        return " ".join(remaining_words).strip()

    return full_text
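# For example:
#   remove_prompt_from_output("Describe this image. A cat on a sofa.",
#                             "Describe this image.")
# returns "A cat on a sofa.", while an output that does not echo the prompt
# is returned unchanged.
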
def load_blip2():
    try:
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        name = os.environ.get("BLIP2_NAME", "Salesforce/blip2-opt-2.7b")
        processor = Blip2Processor.from_pretrained(name, cache_dir=CACHE_DIR)
        model = Blip2ForConditionalGeneration.from_pretrained(
            name,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            cache_dir=CACHE_DIR
        ).to(DEVICE)
        return {"name": name, "processor": processor, "model": model}
    except Exception as e:
        # Free any partially allocated GPU memory before surfacing the error.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        raise HTTPException(500, detail=f"Failed to load BLIP2: {str(e)}")

def caption_blip2(image: Image.Image, prompt: Optional[str], max_new_tokens: int):
    # Double-checked locking: load the model at most once, even when several
    # first requests arrive concurrently.
    if MODELS["blip2"] is None:
        with _loading_locks["blip2"]:
            if MODELS["blip2"] is None:
                MODELS["blip2"] = load_blip2()

    entry = MODELS["blip2"]
    processor = entry["processor"]
    model = entry["model"]
    text = prompt or "Describe this image."

    try:
        inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            # Use the caller-supplied budget (was hardcoded to 128, silently
            # ignoring the max_new_tokens parameter).
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

        full_text = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

        # BLIP2 tends to echo the prompt; strip it before returning.
        caption = remove_prompt_from_output(full_text, text)

        return caption if caption else full_text
    except Exception as e:
        raise HTTPException(500, detail=f"BLIP2 generation failed: {str(e)}")

def load_cogvlm():
    try:
        from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
        name = os.environ.get("COGVLM_NAME", "THUDM/cogvlm2-llama3-captioner")
        processor = AutoProcessor.from_pretrained(name, trust_remote_code=True, cache_dir=CACHE_DIR)
        tokenizer = AutoTokenizer.from_pretrained(
            name,
            trust_remote_code=True,
            use_fast=False,
            cache_dir=CACHE_DIR
        )

        if DEVICE == "cuda" and torch.cuda.device_count() > 1:
            # Shard the model across all visible GPUs.
            model = AutoModelForCausalLM.from_pretrained(
                name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            )
        else:
            # Single GPU or CPU: place the whole model on one device.
            model = AutoModelForCausalLM.from_pretrained(
                name,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            ).to(DEVICE)

        return {"name": name, "processor": processor, "tokenizer": tokenizer, "model": model}
    except Exception as e:
        # Free any partially allocated GPU memory before surfacing the error.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        raise HTTPException(500, detail=f"Failed to load CogVLM: {str(e)}")

def caption_cogvlm(image: Image.Image, prompt: Optional[str], max_new_tokens: int):
    # Double-checked locking, as in caption_blip2.
    if MODELS["cogvlm"] is None:
        with _loading_locks["cogvlm"]:
            if MODELS["cogvlm"] is None:
                MODELS["cogvlm"] = load_cogvlm()

    entry = MODELS["cogvlm"]
    processor = entry["processor"]
    tokenizer = entry["tokenizer"]
    model = entry["model"]
    text = prompt or "Describe this image."

    try:
        # With device_map="auto" the model may be sharded, so send the inputs
        # to the device the model reports rather than assuming DEVICE.
        target_device = model.device if hasattr(model, "device") else DEVICE
        inputs = processor(images=image, text=text, return_tensors="pt").to(target_device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

        full_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()

        caption = remove_prompt_from_output(full_text, text)

        return caption if caption else full_text
    except Exception as e:
        raise HTTPException(500, detail=f"CogVLM generation failed: {str(e)}")

@app.get("/")
def root():
    return {
        "message": "Image Captioning API (BLIP2 / CogVLM)",
        "endpoints": ["/health", "/caption"],
        "device": DEVICE,
        "models": list(MODELS.keys()),
        "cuda_available": torch.cuda.is_available(),
        "cuda_devices": torch.cuda.device_count() if torch.cuda.is_available() else 0
    }


@app.get("/health")
def health():
    return {
        "status": "ok",
        "device": DEVICE,
        "cuda": torch.cuda.is_available(),
        "loaded_models": [k for k, v in MODELS.items() if v is not None]
    }


@app.get("/caption")
def caption_info():
    return {
        "method": "POST",
        "description": "Upload image and get caption",
        "parameters": {
            "file": "image file (required)",
            "model": "blip2 or cogvlm (default: blip2)",
            "prompt": "custom prompt (optional)",
            "max_new_tokens": "max tokens to generate (default: 128)"
        },
        "auth": "Bearer token in Authorization header (if API_TOKEN is set)"
    }
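# Example request, assuming the server runs on localhost:8000 and cat.jpg
# exists locally ($API_TOKEN is only needed when token auth is configured):
#   curl -X POST http://localhost:8000/caption \
#        -H "Authorization: Bearer $API_TOKEN" \
#        -F "file=@cat.jpg" -F "model=blip2" -F "max_new_tokens=64"
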
@app.post("/caption")
async def caption(
    file: UploadFile = File(...),
    model: str = Form("blip2"),
    prompt: Optional[str] = Form(None),
    max_new_tokens: int = Form(128),
    authorization: Optional[str] = Header(None)
):
    check_auth(authorization)

    if model not in ("blip2", "cogvlm"):
        raise HTTPException(400, detail="model must be 'blip2' or 'cogvlm'")

    if max_new_tokens < 1 or max_new_tokens > 512:
        raise HTTPException(400, detail="max_new_tokens must be between 1 and 512")

    try:
        content = await file.read()
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, detail=f"Invalid image file: {str(e)}")

    try:
        if model == "blip2":
            caption_text = caption_blip2(image, prompt, max_new_tokens)
        else:
            caption_text = caption_cogvlm(image, prompt, max_new_tokens)

        return JSONResponse({
            "model": model,
            "caption": caption_text,
            "prompt_used": prompt or "Describe this image.",
            "max_new_tokens": max_new_tokens
        })
    except HTTPException:
        # Re-raise auth, validation, and model-loading errors unchanged.
        raise
    except Exception as e:
        # Anything unexpected becomes a generic 500.
        raise HTTPException(500, detail=f"Caption generation failed: {str(e)}")
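
# Local development entry point: a sketch assuming uvicorn is installed and
# port 8000 is free; production deployments would typically launch
# "uvicorn <module>:app" directly instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)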