# STYLARCH / app.py
# (Hugging Face Spaces page residue, preserved as comments so the module parses:)
# Top-G-420's picture
# Update app.py
# d227da5 verified
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from diffusers import StableDiffusionPipeline
from io import BytesIO
from PIL import Image
import base64
import torch
# Module-level ASGI application (picked up by the server, e.g. `uvicorn app:app`).
app = FastAPI(
    title="Floor Plan Generator API",
    description="Generates clean architectural floor plans from text prompts using maria26/Floor_Plan_LoRA on Stable Diffusion.",
    version="1.0",
)

# CORS: open to every origin so the demo frontend (e.g. on Vercel) can call the API.
# NOTE(review): browsers reject credentialed requests when the allowed origin is the
# wildcard "*" — allow_credentials=True is ineffective here; restrict allow_origins
# to the real frontend URL in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def root():
    """Health-check / landing endpoint; points callers at the interactive docs."""
    # Fix: the original message contained mojibake ("๐Ÿš€") — the UTF-8 bytes of
    # the rocket emoji decoded with the wrong codec. Restored the intended emoji.
    return {"message": "Floor Plan API live 🚀", "docs": "/docs"}
# --- One-time model setup (runs at import time, before the first request) ---

# Prefer GPU when present; half precision saves VRAM on CUDA, full precision on CPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
dtype = torch.float16 if use_cuda else torch.float32

print("Loading Stable Diffusion model...")
generator = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=dtype,
    safety_checker=None,  # No NSFW risk for floor plans
).to(device)

# Layer the floor-plan LoRA weights on top of the base SD v1.5 pipeline.
generator.load_lora_weights("maria26/Floor_Plan_LoRA")

# Memory/speed tweaks; attention slicing also helps on CPU.
generator.enable_attention_slicing()
if use_cuda:
    # NOTE(review): diffusers recommends calling enable_model_cpu_offload() on a
    # CPU-resident pipeline (it manages device placement itself) — calling it after
    # .to("cuda") may not reclaim the expected VRAM; confirm against diffusers docs.
    generator.enable_model_cpu_offload()
print("Model loaded successfully!")
@app.post("/generate")
async def generate(request: Request):
    """Generate a 512x512 floor-plan image from a text prompt.

    Expects a JSON body like ``{"prompt": "<text>"}`` and returns
    ``{"image": "<base64-encoded PNG>"}``.

    Returns 400 for a malformed body or a missing/empty/non-string prompt,
    500 for unexpected generation failures.
    """
    # Fix: a malformed (non-JSON) body previously raised inside the broad
    # try/except and produced a 500 — it is a client error, so return 400.
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({"error": "request body must be valid JSON"}, status_code=400)

    # Fix: a non-string prompt (e.g. a number) previously crashed on .strip()
    # (AttributeError -> 500); validate the type explicitly.
    text = data.get("prompt") if isinstance(data, dict) else None
    if not isinstance(text, str) or not text.strip():
        return JSONResponse({"error": "prompt is required and cannot be empty"}, status_code=400)

    try:
        # Strong negative prompt for clean black-and-white architectural plans
        negative_prompt = "blurry, low quality, text, letters, numbers, watermark, people, furniture, colored, photorealistic, 3d render"
        image = generator(
            prompt=text,
            negative_prompt=negative_prompt,
            num_inference_steps=30,  # Good balance (increase to 50 for better quality on CPU)
            guidance_scale=9.0,  # Strong prompt adherence
            height=512,
            width=512,
        ).images[0]

        # Serialize the PIL image to base64 PNG for JSON transport.
        # (Dropped the redundant buffer.seek(0): getvalue() reads the whole
        # buffer regardless of position.)
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"image": encoded}
    except Exception as e:
        # Last-resort boundary handler for pipeline failures.
        # NOTE(review): str(e) can leak internal details to clients — consider a
        # generic message in production and keep the detail in server logs only.
        print(f"Generation error: {str(e)}")
        return JSONResponse({"error": str(e)}, status_code=500)