# app.py
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import io
import torch
from PIL import Image
from transformers import pipeline
import asyncio
# Disable the auto-generated docs (Swagger UI / ReDoc / OpenAPI schema) for users:
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
# Static files and templates
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
MODEL_ID = os.environ.get("MODEL_ID", "gpt2")
MODEL_TYPE = os.environ.get("MODEL_TYPE", "llm").lower()
TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1","true","yes")
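# Example launch commands (illustrative only; the model IDs and port are assumptions,
# not part of this app — any text-generation or image-to-text checkpoint should work):
#   MODEL_TYPE=llm MODEL_ID=gpt2 uvicorn app:app --host 0.0.0.0 --port 7860
#   MODEL_TYPE=vlm MODEL_ID=Salesforce/blip-image-captioning-base uvicorn app:app --host 0.0.0.0 --port 7860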
# Populated at startup: the loaded pipeline, or the error message if loading failed
pipe = None
load_error = None
def get_device():
    # transformers pipelines take a device index: 0 = first CUDA GPU, -1 = CPU
    return 0 if torch.cuda.is_available() else -1
@app.on_event("startup")
def load_model():
    global pipe, load_error
    try:
        if MODEL_TYPE == "vlm":
            pipe = pipeline("image-to-text", model=MODEL_ID, device=get_device(), trust_remote_code=TRUST_REMOTE_CODE)
        else:
            pipe = pipeline("text-generation", model=MODEL_ID, device=get_device(), trust_remote_code=TRUST_REMOTE_CODE)
    except Exception as e:
        # Record the failure so the endpoints can report it instead of crashing at startup
        load_error = str(e)
async def run_blocking(func, *args, **kwargs):
    # Run the synchronous pipeline call in the default executor so the event loop stays responsive
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
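# Note: on Python 3.9+ this helper could be replaced by asyncio.to_thread(func, *args, **kwargs);
# the executor-based form above also works on older runtimes.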
# Main page for users: a simple form
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "model_type": MODEL_TYPE, "model_id": MODEL_ID, "model_loaded": pipe is not None},
    )
# generate-text (as before, but async)
@app.post("/generate-text")
async def generate_text(payload: dict):
    if MODEL_TYPE == "vlm":
        raise HTTPException(status_code=400, detail="Model is VLM. Use /image-caption.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")
    prompt = payload.get("prompt", "")
    max_new_tokens = payload.get("max_new_tokens", 64)
    do_sample = payload.get("do_sample", False)
    temperature = payload.get("temperature", 0.7)
    outputs = await run_blocking(
        pipe,
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        return_full_text=False,
    )
    if isinstance(outputs, list) and outputs:
        text_out = outputs[0].get("generated_text") or outputs[0].get("text") or str(outputs[0])
    else:
        text_out = str(outputs)
    return {"generated_text": text_out}
# image-caption
@app.post("/image-caption")
async def image_caption(file: UploadFile = File(...)):
    if MODEL_TYPE != "vlm":
        raise HTTPException(status_code=400, detail="Model is LLM. Set MODEL_TYPE=vlm.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}")
    outputs = await run_blocking(pipe, img)
    if isinstance(outputs, list) and outputs:
        caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
    else:
        caption = str(outputs)
    return {"caption": caption}