# app.py
"""FastAPI service exposing either a text-generation LLM or an image-captioning VLM.

Configuration comes from environment variables:
    MODEL_ID          -- Hugging Face model id (default: "gpt2")
    MODEL_TYPE        -- "llm" (text-generation) or "vlm" (image-to-text)
    TRUST_REMOTE_CODE -- "1"/"true"/"yes" to allow remote code in the pipeline
"""
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import io
import torch
from PIL import Image
from transformers import pipeline
import asyncio

# Disable the automatic docs (Swagger view) for users:
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)

# Static files and templates
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

MODEL_ID = os.environ.get("MODEL_ID", "gpt2")
MODEL_TYPE = os.environ.get("MODEL_TYPE", "llm").lower()
TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1", "true", "yes")

pipe = None        # transformers pipeline, populated at startup
load_error = None  # reason string if model loading failed, else None


def get_device():
    """Return the transformers device index: 0 for the first GPU, -1 for CPU."""
    return 0 if torch.cuda.is_available() else -1


# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept here to preserve the existing structure.
@app.on_event("startup")
def load_model():
    """Load the pipeline once at startup.

    On failure, records the error in ``load_error`` instead of crashing,
    so the endpoints can report a 503 with the reason.
    """
    global pipe, load_error
    try:
        task = "image-to-text" if MODEL_TYPE == "vlm" else "text-generation"
        pipe = pipeline(
            task,
            model=MODEL_ID,
            device=get_device(),
            trust_remote_code=TRUST_REMOTE_CODE,
        )
    except Exception as e:  # broad by design: any load failure becomes a 503
        load_error = str(e)


async def run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default executor to keep the event loop free."""
    # get_running_loop() replaces the deprecated get_event_loop() inside coroutines.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


# Main page for users -- a simple form
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the landing page with the current model status."""
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "model_type": MODEL_TYPE,
            "model_id": MODEL_ID,
            "model_loaded": pipe is not None,
        },
    )


# generate-text (same as before, but async)
@app.post("/generate-text")
async def generate_text(payload: dict):
    """Generate text from ``payload["prompt"]`` with the loaded LLM pipeline.

    Raises:
        HTTPException 400: the server is configured as a VLM.
        HTTPException 503: the model is not (yet) loaded.
    """
    if MODEL_TYPE == "vlm":
        raise HTTPException(status_code=400, detail="Model is VLM. Use /image-caption.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")
    prompt = payload.get("prompt", "")
    max_new_tokens = payload.get("max_new_tokens", 64)
    do_sample = payload.get("do_sample", False)
    temperature = payload.get("temperature", 0.7)
    outputs = await run_blocking(
        pipe,
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        return_full_text=False,
    )
    # Pipelines normally return a list of dicts; fall back to str() defensively.
    if isinstance(outputs, list) and outputs:
        text_out = outputs[0].get("generated_text") or outputs[0].get("text") or str(outputs[0])
    else:
        text_out = str(outputs)
    return {"generated_text": text_out}


# image-caption
@app.post("/image-caption")
async def image_caption(file: UploadFile = File(...)):
    """Caption an uploaded image with the loaded VLM pipeline.

    Raises:
        HTTPException 400: the server is configured as an LLM, or the upload
            is not a decodable image.
        HTTPException 503: the model is not (yet) loaded.
    """
    if MODEL_TYPE != "vlm":
        raise HTTPException(status_code=400, detail="Model is LLM. Set MODEL_TYPE=vlm.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}")
    outputs = await run_blocking(pipe, img)
    if isinstance(outputs, list) and outputs:
        caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
    else:
        caption = str(outputs)
    return {"caption": caption}