# Author: MalikSahib1 — commit 5be7e14 ("Update app.py", verified)
# Import necessary libraries
import base64
import io
import threading
import traceback

import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Pydantic model for the request body.
# Validates the JSON body sent to POST /generate-image; FastAPI rejects
# bodies that do not match this schema before the endpoint runs.
class ImageRequest(BaseModel):
    # Text prompt handed unchanged to the diffusion pipeline.
    prompt: str
# Initialize FastAPI app
app = FastAPI()

# Cross-origin access so a browser front end hosted elsewhere (e.g. the
# Netlify site) can call this API: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True together with a wildcard origin may let
# any site make credentialed requests; the endpoints here use no cookies or
# auth, so consider allow_credentials=False or listing the real front-end origin.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],  # any request header
    allow_methods=["*"],  # any HTTP method
    allow_credentials=True,
    allow_origins=["*"],  # any origin
)
# --- MODEL LOADING SECTION (UPDATED FOR LOW MEMORY) ---
# segmind/tiny-sd is a small, efficient Stable Diffusion checkpoint, picked so
# the model fits in a free, memory-constrained environment.
_MODEL_ID = "segmind/tiny-sd"

print("Loading model...")
pipe = StableDiffusionPipeline.from_pretrained(
    _MODEL_ID,
    # Full precision: the pipeline is never moved off the CPU, and float32 is
    # the stable choice there.
    torch_dtype=torch.float32,
)

# Low-memory optimization: attention is computed in slices rather than in one
# large pass, lowering peak memory at some cost in speed.
pipe.enable_attention_slicing()
print("Model loaded successfully!")
# The module-level `pipe` is shared by every request and is not assumed to be
# thread-safe, so generations are serialized (the previous implementation was
# also one-at-a-time, because it blocked the event loop).
_generation_lock = threading.Lock()


def _generate_png_base64(prompt: str) -> str:
    """Generate one image for *prompt* and return it as base64-encoded PNG text.

    Blocking and CPU-heavy; intended to run in a worker thread, not on the
    event loop.
    """
    with _generation_lock:
        # tiny-sd works well with around 25 steps.
        image = pipe(prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
    print("Image generated.")
    # Encode the PIL image as PNG in memory, then as base64 text for JSON.
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Define the API endpoint
@app.post("/generate-image")
async def generate_image(request: ImageRequest):
    """Generate an image from ``request.prompt``.

    Returns ``{"image_data": <base64 PNG>}`` on success. On any failure, returns
    a plain-text 500 response containing the error message.
    """
    prompt = request.prompt
    print(f"Generating image for prompt: {prompt}")
    try:
        # Run the blocking pipeline in a worker thread so the event loop keeps
        # serving other requests (e.g. GET /) while an image is being generated.
        img_str = await run_in_threadpool(_generate_png_base64, prompt)
    except Exception as e:
        # This will print the exact error (with traceback) to the Hugging Face
        # logs for debugging.
        print(f"An error occurred during image generation: {e}")
        traceback.print_exc()
        return Response(content=f"An error occurred: {e}", status_code=500)
    # Return the image as a JSON response with base64 string
    return {"image_data": img_str}
# Liveness probe: a fixed JSON payload confirming the service is up.
@app.get("/")
def read_root():
    status = {"Status": "API is running"}
    return status