"""Floor Plan Generator API.

FastAPI service that turns a text prompt into a 512x512 PNG floor plan using
Stable Diffusion v1.5 with the ``maria26/Floor_Plan_LoRA`` weights applied.
The model is loaded once, at import time; ``POST /generate`` returns the image
base64-encoded inside a JSON body.
"""

import asyncio
import base64
import logging
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Floor Plan Generator API",
    description="Generates clean architectural floor plans from text prompts using maria26/Floor_Plan_LoRA on Stable Diffusion.",
    version="1.0",
)

# Add CORS middleware (allows your Vercel frontend to call the API)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for public demo (restrict to your Vercel URL in production)
    # NOTE(review): wildcard origins combined with credentials is discouraged by
    # the CORS spec; if the frontend sends no cookies/auth headers, consider False.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    """Liveness endpoint; also points callers at the interactive API docs."""
    return {"message": "Floor Plan API live 🚀", "docs": "/docs"}


# Device setup (works on CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# Load base Stable Diffusion v1.5 + LoRA
print("Loading Stable Diffusion model...")
generator = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=dtype,
    safety_checker=None,  # No NSFW risk for floor plans
)

# Apply the Floor Plan LoRA
generator.load_lora_weights("maria26/Floor_Plan_LoRA")

# Optimizations for speed and memory
generator.enable_attention_slicing()
if device == "cuda":
    # Model CPU offload places each sub-model on the GPU only while it runs,
    # so the pipeline is deliberately NOT moved to CUDA first: doing so would
    # load the whole model into VRAM and defeat the point of offloading.
    generator.enable_model_cpu_offload()  # Helps on limited GPU VRAM
else:
    generator.to(device)
print("Model loaded successfully!")

# Strong negative prompt for clean black-and-white architectural plans
NEGATIVE_PROMPT = "blurry, low quality, text, letters, numbers, watermark, people, furniture, colored, photorealistic, 3d render"

# Generation is blocking CPU/GPU work, so it runs off the event loop. A single
# worker thread serializes runs: the shared pipeline (including its scheduler)
# keeps per-run state and must not be called concurrently. Pending requests
# wait in this executor's queue instead of stalling other endpoints.
_generation_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="floorplan-gen")


def _render_png_base64(prompt: str) -> str:
    """Run the pipeline on *prompt* and return the image as base64-encoded PNG text.

    Blocking; only call it through ``_generation_executor``.
    """
    image = generator(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=30,  # Good balance (increase to 50 for better quality on CPU)
        guidance_scale=9.0,  # Strong prompt adherence
        height=512,
        width=512,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/generate")
async def generate(request: Request):
    """Generate a floor plan from a JSON body of the form ``{"prompt": "<text>"}``.

    Returns:
        ``{"image": "<base64 PNG>"}`` on success; ``{"error": ...}`` with status
        400 when the body is not valid JSON or the prompt is missing, empty, or
        not a string; ``{"error": ...}`` with status 500 when generation fails.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        return JSONResponse({"error": "request body must be valid JSON"}, status_code=400)

    # A JSON array/string/number body has no "prompt" key; treat it as missing.
    text = data.get("prompt") if isinstance(data, dict) else None
    if text is not None and not isinstance(text, str):
        return JSONResponse({"error": "prompt must be a string"}, status_code=400)
    if not text or not text.strip():
        return JSONResponse({"error": "prompt is required and cannot be empty"}, status_code=400)

    try:
        loop = asyncio.get_running_loop()
        encoded = await loop.run_in_executor(_generation_executor, _render_png_base64, text)
    except Exception as e:
        logger.exception("Generation error")
        return JSONResponse({"error": str(e)}, status_code=500)
    return {"image": encoded}