File size: 2,923 Bytes
7237554
 
 
 
 
 
 
 
 
61a06ba
 
 
 
 
 
e912f66
7237554
 
e912f66
7237554
 
 
 
 
 
 
 
 
e912f66
7237554
 
e912f66
7237554
 
61a06ba
 
 
 
 
7237554
e912f66
7237554
 
 
e912f66
7237554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e912f66
7237554
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Response
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import io
import base64
from typing import Optional
import uvicorn
import os

# Set Hugging Face cache directories to writable locations.
# Hugging Face Spaces containers typically only permit writes under /tmp, so
# the default cache locations under ~/.cache would fail — TODO confirm for
# this Space's hardware/persistence settings.
# NOTE(review): diffusers/transformers are imported above, but they appear to
# resolve these variables lazily at download time; confirm on library upgrade.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — kept here for older versions.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets_cache"

# Initialize FastAPI app (module-level so Spaces/uvicorn can import `app`)
app = FastAPI(title="Stable Diffusion API")

# Request schema for POST /generate.
class TextToImageRequest(BaseModel):
    """Parameters accepted by the /generate endpoint.

    Only ``prompt`` is required; the remaining fields default to common
    Stable Diffusion settings. ``negative_prompt`` and ``seed`` are truly
    optional — ``None`` is a meaningful value forwarded to the pipeline.
    The numeric generation parameters are declared non-nullable: previously
    they were ``Optional[...]`` with non-None defaults, so a client sending
    an explicit JSON ``null`` had ``None`` forwarded into the pipeline and
    crashed with a 500; now such input is rejected up front with a 422.
    """

    prompt: str
    negative_prompt: Optional[str] = None
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    height: int = 512
    width: int = 512
    seed: Optional[int] = None

# Load the model (will be loaded when the Space is initialized)
# The pipeline is constructed once at import time, so every request shares a
# single in-memory model and the multi-GB download cost is paid at startup,
# not on the first request.
model_id = "CompVis/stable-diffusion-v1-4"

# Check if CUDA is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, 
    cache_dir="/tmp/diffusers_cache",  # writable path on Spaces (defaults under ~/.cache may be read-only)
    token=os.environ.get("HF_TOKEN")  # Use token from environment variable; presumably None means anonymous access — verify for gated models
)
pipe = pipe.to(device)

@app.get("/")
def read_root():
    return {"message": "Stable Diffusion API is running. Use POST /generate endpoint."}

@app.post("/generate")
async def generate_image(request: TextToImageRequest):
    try:
        # Set seed if provided
        if request.seed is not None:
            generator = torch.Generator(device=device).manual_seed(request.seed)
        else:
            generator = None
        
        # Generate image
        image = pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            height=request.height,
            width=request.width,
            generator=generator
        ).images[0]
        
        # Convert to base64
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
        
        return JSONResponse({
            "status": "success",
            "image": img_str,
            "parameters": {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt,
                "steps": request.num_inference_steps,
                "guidance_scale": request.guidance_scale,
                "dimensions": f"{request.width}x{request.height}",
                "seed": request.seed
            }
        })
    
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)}
        )

# Local-development entry point only; on Hugging Face Spaces the platform
# starts the server itself, so this branch never runs there.
if __name__ == "__main__":
    # 0.0.0.0 so the server is reachable from outside the container;
    # 7860 is the port Spaces conventionally exposes.
    uvicorn.run(app, host="0.0.0.0", port=7860)