# Pasted from a Hugging Face "Spaces" page whose status banner read "Runtime error".
# app.py — FastAPI front end serving a Hugging Face pipeline (text LLM or image VLM).
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import io
import torch
from PIL import Image
from transformers import pipeline
import asyncio

# Disable the auto-generated API docs (Swagger/ReDoc/OpenAPI views) for end users.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)

# Static assets and Jinja2 templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Configuration from environment variables.
MODEL_ID = os.environ.get("MODEL_ID", "gpt2")
MODEL_TYPE = os.environ.get("MODEL_TYPE", "llm").lower()  # "vlm" selects image-to-text; anything else means text-generation
TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1","true","yes")

# Populated by load_model(): the HF pipeline on success, or the failure message.
pipe = None
load_error = None
def get_device():
    """Pick the pipeline device index: CUDA GPU 0 when available, else CPU (-1)."""
    if torch.cuda.is_available():
        return 0
    return -1
def load_model():
    """Build the Hugging Face pipeline for the configured model.

    On success the module-global `pipe` is set; on any failure the exception
    text is stored in `load_error` so request handlers can answer 503 instead
    of the process dying at startup.
    """
    global pipe, load_error
    # MODEL_TYPE drives the pipeline task; comparing strings cannot raise,
    # so the selection is hoisted out of the try block.
    task = "image-to-text" if MODEL_TYPE == "vlm" else "text-generation"
    try:
        pipe = pipeline(
            task,
            model=MODEL_ID,
            device=get_device(),
            trust_remote_code=TRUST_REMOTE_CODE,
        )
    except Exception as e:
        load_error = str(e)
async def run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default thread-pool executor.

    Keeps the event loop responsive while the (CPU/IO-bound) pipeline call runs.

    Fix: use `asyncio.get_running_loop()` instead of `get_event_loop()` —
    the latter is deprecated inside coroutines since Python 3.10 and can
    create/return the wrong loop; inside a coroutine there is always a
    running loop to fetch directly.
    """
    loop = asyncio.get_running_loop()
    # Lambda captures *args/**kwargs because run_in_executor takes only positionals.
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
# Main page for users — a simple form.
# NOTE(review): no @app.get("/") decorator is visible in this paste, so as
# written this handler is never routed — presumably the decorator was lost
# when the file was copied; confirm against the original source.
def index(request: Request):
    """Render the index template, exposing model id/type and load status to the UI."""
    return templates.TemplateResponse("index.html", {"request": request, "model_type": MODEL_TYPE, "model_id": MODEL_ID, "model_loaded": pipe is not None})
# generate-text handler (async).
# NOTE(review): no route decorator is visible in this paste — presumably lost on copy.
async def generate_text(payload: dict):
    """Generate text from `payload["prompt"]` with the text-generation pipeline.

    Raises 400 if the configured model is a VLM, and 503 while the pipeline
    is unavailable (failed to load or still loading).
    """
    if MODEL_TYPE == "vlm":
        raise HTTPException(status_code=400, detail="Model is VLM. Use /image-caption.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")

    # Generation options with the same defaults the UI relies on.
    gen_kwargs = {
        "max_new_tokens": payload.get("max_new_tokens", 64),
        "do_sample": payload.get("do_sample", False),
        "temperature": payload.get("temperature", 0.7),
        "return_full_text": False,
    }
    outputs = await run_blocking(pipe, payload.get("prompt", ""), **gen_kwargs)

    # Pipelines normally return [{"generated_text": ...}]; degrade gracefully otherwise.
    text_out = str(outputs)
    if isinstance(outputs, list) and outputs:
        first = outputs[0]
        text_out = first.get("generated_text") or first.get("text") or str(first)
    return {"generated_text": text_out}
# image-caption handler.
# NOTE(review): no route decorator is visible in this paste — presumably lost on copy.
async def image_caption(file: UploadFile = File(...)):
    """Caption an uploaded image with the image-to-text pipeline.

    Raises 400 for a non-VLM model or an undecodable image, and 503 while
    the pipeline is unavailable.
    """
    if MODEL_TYPE != "vlm":
        raise HTTPException(status_code=400, detail="Model is LLM. Set MODEL_TYPE=vlm.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")

    raw = await file.read()
    try:
        # Decode and normalize to RGB so the pipeline gets a consistent mode.
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}")

    outputs = await run_blocking(pipe, img)

    # Pipelines normally return [{"generated_text": ...}]; degrade gracefully otherwise.
    caption = str(outputs)
    if isinstance(outputs, list) and outputs:
        first = outputs[0]
        caption = first.get("generated_text") or first.get("caption") or str(first)
    return {"caption": caption}