|
|
import base64
import threading
from io import BytesIO

import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
|
|
|
|
|
# The FastAPI application. Its version, title and description appear in the
# generated OpenAPI documentation.
app = FastAPI(
    version="1.0",
    title="Floor Plan Generator API",
    description=(
        "Generates clean architectural floor plans from text prompts "
        "using maria26/Floor_Plan_LoRA on Stable Diffusion."
    ),
)
|
|
|
|
|
|
|
|
# Let browser front-ends on any origin call the API.
#
# Credentials are deliberately disabled:
# - None of the endpoints in this file read cookies or auth headers.
# - A wildcard origin must not be paired with credentials.
# - With allow_credentials=True, Starlette echoes the caller's Origin for
#   requests that carry cookies. That would let any site make credentialed
#   cross-origin requests.
#
# If cookie/auth support is ever needed, list the allowed origins explicitly
# instead of using "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
@app.get("/")
def root():
    """Landing endpoint: report that the service is up and where the docs live."""
    # NOTE(review): the character after "live" looks like a mis-decoded emoji;
    # confirm the intended glyph before changing this string.
    return dict(message="Floor Plan API live ๐", docs="/docs")
|
|
|
|
|
|
|
|
# Use the GPU when one is available. Half precision is used only on CUDA;
# the CPU path stays in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32


def _load_pipeline():
    """Build the SD v1.5 pipeline with the floor-plan LoRA applied.

    Returns:
        A ``StableDiffusionPipeline`` ready for inference on ``device``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
        safety_checker=None,
    )
    pipe.load_lora_weights("maria26/Floor_Plan_LoRA")
    pipe.enable_attention_slicing()

    if device == "cuda":
        # Model CPU offload manages device placement itself, moving each
        # sub-model to the GPU only while it runs. Calling .to("cuda") first
        # would load the whole pipeline onto the GPU just for the offload
        # setup to move it back to the CPU.
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)
    return pipe


print("Loading Stable Diffusion model...")

# Shared pipeline used by the /generate endpoint. The name is kept for
# callers; note it is a diffusers pipeline, not a torch.Generator.
generator = _load_pipeline()

print("Model loaded successfully!")
|
|
|
|
|
# Applied to every request, to steer the model away from text, people,
# colour and photorealism.
NEGATIVE_PROMPT = (
    "blurry, low quality, text, letters, numbers, watermark, people, "
    "furniture, colored, photorealistic, 3d render"
)

# A single pipeline object is shared by all requests. Inference runs in
# worker threads (see generate), so calls into the pipeline are serialized
# to avoid concurrent use of its shared state.
_generation_lock = threading.Lock()


def _run_pipeline(prompt):
    """Run the shared pipeline for *prompt* and return the PNG as base64 text.

    This call blocks; it is meant to run in a worker thread, not on the
    event loop.
    """
    with _generation_lock:
        image = generator(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            num_inference_steps=30,
            guidance_scale=9.0,
            height=512,
            width=512,
        ).images[0]

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    # getvalue() returns the whole buffer regardless of position, so no seek is needed.
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/generate")
async def generate(request: Request):
    """Generate a floor-plan image from the ``prompt`` field of a JSON body.

    Expects a JSON object with a non-empty string ``prompt``.

    Returns:
        ``{"image": <base64-encoded PNG>}`` on success.
        A 400 JSON error if the body is not valid JSON, or if ``prompt`` is
        missing, empty or not a string.
        A 500 JSON error carrying the exception message if generation fails.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return JSONResponse({"error": "request body must be valid JSON"}, status_code=400)

    # A non-object body (list, string, number) has no "prompt" field.
    text = data.get("prompt") if isinstance(data, dict) else None
    if not isinstance(text, str) or not text.strip():
        return JSONResponse({"error": "prompt is required and cannot be empty"}, status_code=400)

    try:
        # Inference is long and blocking. Run it off the event loop so other
        # requests (e.g. "/") are still served meanwhile.
        encoded = await run_in_threadpool(_run_pipeline, text)
    except Exception as e:
        print(f"Generation error: {e}")
        return JSONResponse({"error": str(e)}, status_code=500)

    return {"image": encoded}