# My-Img-Space / app.py
# Hugging Face Space by Valtry — commit cf9dbf6 ("Update app.py")
import os
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from diffusers import StableDiffusionPipeline
import torch
from io import BytesIO
import base64
import logging
# ---------- Configuration ----------
# Hugging Face caches must live on a writable path inside the container/Space;
# setdefault only applies these when the variable is not already set.
_CACHE_ENV_DEFAULTS = {
    "TRANSFORMERS_CACHE": "/tmp/hf/transformers",
    "HF_HOME": "/tmp/hf",
    "HF_DATASETS_CACHE": "/tmp/hf/datasets",
    "HF_MODULES_CACHE": "/tmp/hf/modules",
    "XDG_CACHE_HOME": "/tmp/hf/xdg",
}
for _env_var, _env_path in _CACHE_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_var, _env_path)

MODEL_ID = "Valtry/My-Img"  # change if necessary
USE_LOW_CPU_MEM = True  # requires accelerate installed (recommended)
# ---------- App and static ----------
app = FastAPI()
# Serve static files from ./static; index.html is served at "/" by the
# FileResponse route below (not by mounting StaticFiles at "/"), so API
# routes are not shadowed by the static mount.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
async def index():
    """Serve the static landing page.

    Uses FileResponse rather than a root StaticFiles mount so the API
    routes defined elsewhere in this module are not overridden.
    """
    index_path = "static/index.html"
    return FileResponse(index_path)
class PromptRequest(BaseModel):
    """Request body for POST /generate."""

    # Free-text prompt forwarded to the diffusion pipeline.
    prompt: str
# ---------- Model loading ----------
logger = logging.getLogger("uvicorn.error")

# Prefer the GPU when CUDA is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Device detected: %s", device)

# Populated by load_pipeline(); None means the model is not loaded yet.
pipe = None
def load_pipeline():
    """Load the Stable Diffusion pipeline into the module-level ``pipe``.

    Returns:
        tuple: ``(True, None)`` on success, ``(False, error_message)`` on
        failure. Failures are logged but never raised, so startup can
        continue and the /generate endpoint can retry later.
    """
    global pipe
    try:
        # fp16 halves GPU memory use; CPU inference requires fp32.
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        load_kwargs = {
            "torch_dtype": torch_dtype,
            "cache_dir": "/tmp/hf",
        }
        # low_cpu_mem_usage is only valid when accelerate is installed.
        if USE_LOW_CPU_MEM:
            load_kwargs["low_cpu_mem_usage"] = True
        # Lazy %-style args avoid building the message unless INFO is enabled;
        # cache_dir is elided from the log to keep it readable.
        logger.info(
            "Loading pipeline %s with kwargs: %s",
            MODEL_ID,
            {k: v for k, v in load_kwargs.items() if k != "cache_dir"},
        )
        pipe_local = StableDiffusionPipeline.from_pretrained(MODEL_ID, **load_kwargs)
        # Move the whole pipeline to the detected device in one call
        # (the previous cuda/cpu if/else branches were identical).
        pipe_local = pipe_local.to(device)
        logger.info("Model loaded successfully")
        pipe = pipe_local
        return True, None
    except Exception as e:
        # Boundary handler: log with traceback and report failure to the caller.
        logger.exception("Failed to load pipeline")
        return False, str(e)
# Attempt to load the model once at import/startup time; a failure is logged
# but does not abort startup (the /generate endpoint retries on demand).
_loaded_ok, _load_error = load_pipeline()
if not _loaded_ok:
    logger.error(f"Initial model load failed: {_load_error}")
# ---------- Endpoint ----------
@app.post("/generate")
def generate(prompt_request: PromptRequest):
    """Generate an image for the given prompt and return it base64-encoded.

    Declared as a plain (sync) ``def`` so FastAPI runs it in its threadpool:
    the diffusion call is CPU/GPU-bound and would block the event loop for
    the entire generation if run inside an ``async def`` handler.

    Returns:
        JSONResponse: ``{"image": <base64 PNG>}`` on success, or an error
        payload with status 400 (empty prompt) / 500 (model not loaded or
        generation failure).
    """
    global pipe
    if pipe is None:
        # Model failed to load at startup — retry once before giving up.
        ok, err = load_pipeline()
        if not ok:
            return JSONResponse({"error": "Model not loaded", "details": err}, status_code=500)

    prompt = prompt_request.prompt or ""
    if not prompt.strip():
        return JSONResponse({"error": "Prompt is empty"}, status_code=400)

    try:
        # If you want to expose additional options (steps, guidance_scale), expand this.
        out = pipe(prompt)
        image = out.images[0]
        # Encode the PIL image as a base64 PNG for the JSON payload.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return JSONResponse({"image": img_b64})
    except Exception as e:
        # Boundary handler: log with traceback, return a structured 500.
        logger.exception("Generation failed")
        return JSONResponse({"error": "Generation failed", "details": str(e)}, status_code=500)