# FastAPI service exposing a Stable Diffusion text-to-image endpoint
# (intended to run as a Hugging Face Space).
| import os | |
| from fastapi import FastAPI | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| from io import BytesIO | |
| import base64 | |
| import logging | |
# ---------- Configuration ----------
# Redirect every Hugging Face cache location to a writable path inside the
# container/Space: the default home directory is typically read-only there.
# setdefault keeps any value the operator already exported.
_CACHE_ENV_DEFAULTS = {
    "TRANSFORMERS_CACHE": "/tmp/hf/transformers",
    "HF_HOME": "/tmp/hf",
    "HF_DATASETS_CACHE": "/tmp/hf/datasets",
    "HF_MODULES_CACHE": "/tmp/hf/modules",
    "XDG_CACHE_HOME": "/tmp/hf/xdg",
}
for _env_name, _env_path in _CACHE_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_path)

MODEL_ID = "Valtry/My-Img"  # change if necessary
USE_LOW_CPU_MEM = True  # requires accelerate installed (recommended)

# ---------- App and static ----------
app = FastAPI()
# Serve static files from ./static, index.html will be shown at "/"
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
async def index():
    """Serve the frontend entry page at the site root.

    The route decorator was missing, so this handler was never registered
    and "/" returned 404 even though the comments promise index.html there.
    FileResponse is used (rather than mounting static at "/") so API routes
    are not shadowed by the static mount.
    """
    return FileResponse("static/index.html")
class PromptRequest(BaseModel):
    """JSON request body for the image-generation endpoint."""

    # Free-form text prompt describing the image to generate.
    prompt: str
# ---------- Model loading ----------
# Reuse uvicorn's error logger so messages land in the server log stream.
logger = logging.getLogger("uvicorn.error")

# Pick the accelerator once at import time; everything downstream keys off it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Device detected: {device}")

# Populated by load_pipeline(); None means the model is not (yet) loaded.
pipe = None
def load_pipeline():
    """Load the Stable Diffusion pipeline into the module-global `pipe`.

    Returns a `(success, error)` pair: `(True, None)` when the model is
    loaded and moved to the detected device, `(False, message)` otherwise.
    Failures are logged with a full traceback and never raised, so the
    server stays up even when the model cannot be fetched.
    """
    global pipe
    try:
        # fp16 only makes sense on GPU; CPU inference needs full precision.
        dtype = torch.float16 if device == "cuda" else torch.float32
        kwargs = {
            "torch_dtype": dtype,
            "cache_dir": "/tmp/hf",
        }
        if USE_LOW_CPU_MEM:
            # low_cpu_mem_usage only valid if accelerate is available; pass if wanted
            kwargs["low_cpu_mem_usage"] = True
        shown = {k: v for k, v in kwargs.items() if k != "cache_dir"}
        logger.info(f"Loading pipeline {MODEL_ID} with kwargs: {shown}")
        loaded = StableDiffusionPipeline.from_pretrained(MODEL_ID, **kwargs)
        # Move the whole pipeline onto the selected device ("cuda" or "cpu").
        loaded = loaded.to(device)
        logger.info("Model loaded successfully")
        pipe = loaded
        return True, None
    except Exception as e:
        logger.exception("Failed to load pipeline")
        return False, str(e)
# Attempt load at startup so the first request doesn't pay the cold-start
# cost; a failure is only logged (the server keeps running, and generate()
# will retry the load on demand).
_success, _err = load_pipeline()
if not _success:
    logger.error(f"Initial model load failed: {_err}")
# ---------- Endpoint ----------
@app.post("/generate")
async def generate(prompt_request: PromptRequest):
    """Generate one image for the given prompt.

    The route decorator was missing, leaving this endpoint unregistered;
    it is now exposed as POST /generate.

    Returns JSON ``{"image": <base64-encoded PNG>}`` on success,
    400 when the prompt is blank, and 500 when the model cannot be
    loaded or generation fails (with a ``details`` field).
    """
    global pipe
    if pipe is None:
        # Startup load failed (or never ran): retry once on demand.
        ok, err = load_pipeline()
        if not ok:
            return JSONResponse({"error": "Model not loaded", "details": err}, status_code=500)
    prompt = prompt_request.prompt or ""
    if not prompt.strip():
        return JSONResponse({"error": "Prompt is empty"}, status_code=400)
    try:
        # Run the diffusion pipeline with default steps/guidance.
        # If you want to expose additional options (steps, guidance_scale), expand this.
        out = pipe(prompt)
        image = out.images[0]
        # Encode the PIL image as an in-memory PNG, then base64 for JSON transport.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        buffer.seek(0)
        img_b64 = base64.b64encode(buffer.read()).decode("utf-8")
        return JSONResponse({"image": img_b64})
    except Exception as e:
        logger.exception("Generation failed")
        return JSONResponse({"error": "Generation failed", "details": str(e)}, status_code=500)