# wallgen / backend/main.py
# Last commit: df66415 "Update backend/main.py" by devankit (verified)
import base64
import io
import os
import random
from typing import Optional

import numpy as np
import torch
import uvicorn
from diffusers import StableDiffusionPipeline, LCMScheduler
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel, Field
from sklearn.cluster import KMeans
print("Model will load on first request...")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def env_flag(name: str) -> bool:
    """Return True if environment variable *name* holds a truthy value.

    Unset, empty, and the usual "off" spellings ("0", "false", "no", "off";
    case-insensitive, surrounding whitespace ignored) count as False. Any
    other non-empty value counts as True.
    """
    value = os.getenv(name, "").strip().lower()
    return value not in ("", "0", "false", "no", "off")


# Force CPU in Docker container for consistency. The flag is parsed rather
# than tested for plain truthiness so DOCKER_CONTAINER=false / 0 does not
# force CPU.
device = "cpu" if env_flag("DOCKER_CONTAINER") else ("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): nothing in this file currently reads MODEL_CACHE_DIR.
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", os.path.abspath(os.path.join(BASE_DIR, "..", "models")))
pipe = None  # lazy-loaded on first generation request
def load_pipeline():
    """Return the shared Stable Diffusion pipeline, loading it on first call.

    The LCM Dreamshaper v7 weights are loaded from the Hugging Face Hub with
    the safety checker disabled, and the scheduler is swapped for an
    ``LCMScheduler``. The result is cached in the module-level ``pipe``, so
    later calls return the same object.

    Returns:
        The ready-to-use ``StableDiffusionPipeline``.

    Raises:
        Whatever ``StableDiffusionPipeline.from_pretrained`` raises (network,
        download or weight-loading errors) propagates to the caller.
    """
    global pipe
    if pipe is not None:
        return pipe
    print("Loading model into memory...")
    # Use the module-level device choice (which may force CPU inside Docker)
    # instead of re-probing CUDA here.
    local_device = device
    # Half precision is only worthwhile on a GPU; many CPU kernels lack
    # float16 support or are very slow with it.
    dtype = torch.float16 if local_device == "cuda" else torch.float32
    p = StableDiffusionPipeline.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7",
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
        low_cpu_mem_usage=True,
    )
    p.scheduler = LCMScheduler.from_config(p.scheduler.config)
    if local_device == "cuda":
        # Model CPU offload moves sub-models to the GPU on demand itself.
        # Moving the whole pipeline to CUDA beforehand only wastes time.
        p.enable_model_cpu_offload()
    else:
        p = p.to(local_device)
        p.enable_attention_slicing()
    pipe = p
    print("Model has been loaded successfully!")
    return pipe
def extract_colors(image: Image.Image, num_colors=6):
    """Return *num_colors* representative colours of *image* as ``#rrggbb`` strings.

    All pixels are clustered with k-means, using a fixed ``random_state`` so
    repeated calls on the same image give the same palette. Each cluster
    centre becomes one palette entry, in KMeans' cluster-index order (not
    sorted by prominence).
    """
    # Normalise to 3-channel RGB so RGBA, greyscale and palette-mode images
    # don't break the (-1, 3) reshape below.
    pixels = np.asarray(image.convert("RGB")).reshape(-1, 3)
    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)
    # Round rather than truncate, so centres aren't biased towards darker
    # values.
    centers = np.rint(kmeans.cluster_centers_).astype(int)
    return ['#{:02x}{:02x}{:02x}'.format(r, g, b) for r, g, b in centers]
def resize_image(image: Image.Image, width: int, height: int):
    """Scale *image* to exactly ``width`` x ``height`` with Lanczos resampling.

    The aspect ratio is not preserved; the image is stretched to fit.
    """
    target_size = (width, height)
    lanczos = Image.Resampling.LANCZOS
    return image.resize(target_size, lanczos)
app = FastAPI()
# Allow any origin to call the API. Credentials are disabled: nothing in this
# file uses cookies or auth headers. Also, browsers reject a literal "*"
# origin on credentialed requests, and Starlette works around that by echoing
# the caller's Origin, which would give every site credentialed access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerationRequest(BaseModel):
    """JSON body accepted by ``POST /generate-wallpaper/``."""

    # Style name; the endpoint falls back to "Gradient" for unknown names.
    style: str = "Gradient"
    # Seed for the diffusion noise.
    seed: int = 42
    # Resolution label; the endpoint falls back to 1920x1080 for unknown labels.
    resolution: str = "Desktop (1920x1080)"
    # Number of denoising steps. Bounded because each step is a full model
    # pass and this value comes straight from the client; LCM models are
    # designed for only a handful of steps.
    steps: int = Field(default=4, ge=1, le=50)
    # Optional free-text colour woven into the prompt; blank means none.
    color: Optional[str] = None
# Style name -> base prompt. Unknown styles fall back to "Gradient".
_STYLE_PROMPTS = {
    "Geometric": "sharp geometric patterns, triangular tessellations, hexagonal grids, precise angular shapes, colorful polygon arrangements, geometric mandala designs, crystalline symmetric patterns, abstract mathematical forms, vector art precision, clean geometric compositions, bright geometric color blocks, structured pattern design, 8k wallpaper",
    "Organic": "flowing organic forms, silky fluid textures, soft ethereal lighting, high-detail depth, oil painting feel, zen abstract design, bioluminescent flowing lines, translucent silk ribbons, watercolor blending, natural formations",
    "Vibrant": "futuristic neon shapes, cyberpunk color palette, dark background, glowing fractals, high contrast, vivid reflections, neon geometric patterns, electric plasma effects, holographic surfaces, laser light trails, ultra bright colors",
    "Minimal": "watercolor drops on white canvas, circular water color spots, bright colored water drops spreading on paper, wet watercolor bleeding circles, vibrant color drops with soft edges, colorful water stains on clean background, translucent watercolor circles, artistic water drops pattern, high contrast bright drops, clean minimalist watercolor design",
    "Ink Flow": "flowing liquid ink patterns, organic ink bleeds, watercolor paint flows, dynamic brush strokes, fluid color transitions, ink spreading on wet paper, abstract paint movements, colorful ink drops merging, artistic liquid patterns, paint flow dynamics",
    "Gradient": "smooth color gradients, organic flowing waves, ethereal abstract patterns, bright vivid color transitions, watercolor style blending, soft gradient meshes, flowing energy patterns, color field painting, seamless color flows, high resolution gradients",
    "Crystal": "high-resolution crystal surfaces, prism reflections, iridescent colors, shattered glass texture, futuristic diamond shapes, cinematic lighting, dichroic glass effects, crystalline structures, refracting light patterns",
    "Retro Wave": "synthwave aesthetic, neon grid lines, 1980s retro futuristic, purple pink gradients, chrome reflections, vintage synthesizer vibes, outrun highway aesthetic, neon palm trees, retrowave sunset",
    "Botanical": "botanical illustrations, detailed leaves and flowers, vintage naturalist style, earth tones, Art Nouveau vine patterns, organic plant geometry, natural forms, herbarium specimens, botanical cross-sections",
    "Space Cosmic": "deep space nebula, cosmic dust clouds, distant galaxies, purple blue cosmic colors, stellar formations, aurora borealis patterns, cosmic energy flows, nebula formations, space photography aesthetic",
    "Psychedelic": "trippy psychedelic patterns, kaleidoscope effects, vibrant swirling colors, mind-bending fractals, mandala patterns, liquid light shows, synesthetic color flows, geometric portals, infinite pattern recursion",
    "Industrial": "metal textures, rust patterns, industrial materials, concrete and steel aesthetics, weathered surfaces, industrial photography, high contrast lighting, urban decay patterns, metallic reflections",
    "Fractal": "complex fractal spirals, infinite zoom illusion, glowing edges, depth of field blur, mathematical precision, cosmic art, Mandelbrot-inspired structure",
}

# Resolution label -> (width, height) of the delivered image.
_RESOLUTIONS = {
    "Mobile Portrait (1080x1920)": (1080, 1920),
    "Desktop (1920x1080)": (1920, 1080),
    "Square (1080x1080)": (1080, 1080),
    "Ultrawide (2560x1080)": (2560, 1080),
    "4K Desktop (3840x2160)": (3840, 2160),
}
_DEFAULT_RESOLUTION = (1920, 1080)

# Quality enhancers; 2-3 of them are added to every prompt for variety.
_QUALITY_MODIFIERS = [
    "masterpiece", "best quality", "ultra-detailed", "8k resolution",
    "professional", "award-winning", "trending on artstation", "highly detailed",
]
_ARTISTIC_STYLES = [
    "oil painting style", "digital art", "photorealistic", "concept art",
    "matte painting", "volumetric lighting", "ray tracing", "octane render",
]
_COMPOSITIONS = [
    "rule of thirds", "golden ratio composition", "symmetrical balance",
    "dynamic composition", "centered composition", "diagonal flow",
]

# Steer away from common artefacts and keep the focus abstract.
_NEGATIVE_PROMPT = "low quality, blurry, pixelated, watermark, text, logo, signature, artifacts, distorted, people, person, human, face, car, vehicle, building, realistic objects, photography, portrait, landscape"

# Pixel budget for the rendered (pre-upscale) image: the 512x512 area used
# for square output, kept roughly constant for other aspect ratios.
_GENERATION_PIXELS = 512 * 512
# Render sizes are snapped to this grid; diffusers requires multiples of 8,
# and Stable Diffusion is commonly run at multiples of 64.
_GENERATION_MULTIPLE = 64


def _build_prompt(style, color, rng):
    """Return the full text prompt for *style* and optional *color*.

    All random choices come from *rng*, so a seeded ``random.Random`` yields
    the same prompt every time.
    """
    base_prompt = _STYLE_PROMPTS.get(style, _STYLE_PROMPTS["Gradient"])
    selected_modifiers = rng.sample(_QUALITY_MODIFIERS, k=rng.randint(2, 3))
    prompt = f"{base_prompt}, {', '.join(selected_modifiers)}"
    # Occasionally add an artistic style and/or a composition hint.
    if rng.random() > 0.7:
        prompt += f", {rng.choice(_ARTISTIC_STYLES)}"
    if rng.random() > 0.6:
        prompt += f", {rng.choice(_COMPOSITIONS)}"
    color = (color or "").strip()
    if color:
        color_integration = rng.choice([
            f"dominated by {color} tones",
            f"featuring prominent {color} accents",
            f"with {color} color palette",
            f"infused with {color} hues",
            f"{color} color harmony",
        ])
        prompt += f", {color_integration}"
    return prompt


def _generation_size(target_width, target_height):
    """Return a render ``(width, height)`` with the target's aspect ratio.

    The area stays near ``_GENERATION_PIXELS`` and both sides snap to
    ``_GENERATION_MULTIPLE``. Rendering at the right shape avoids stretching
    a square image to wide or tall outputs.
    """
    aspect = target_width / target_height
    raw_width = (_GENERATION_PIXELS * aspect) ** 0.5
    raw_height = raw_width / aspect

    def snap(value):
        steps = max(1, round(value / _GENERATION_MULTIPLE))
        return steps * _GENERATION_MULTIPLE

    return snap(raw_width), snap(raw_height)


def _to_png_data_url(image):
    """Encode a PIL image as a ``data:image/png;base64,...`` URL."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{img_str}"


@app.post("/generate-wallpaper/")
async def generate_wallpaper_endpoint(request: GenerationRequest):
    """Generate a wallpaper for *request* and return it with its colour palette.

    The JSON response has:
        - ``image``: a PNG data URL at the requested resolution.
        - ``palette``: a list of hex colours.
        - ``seed``, ``style`` and ``steps``: echoed from the request.
        - ``resolution``: the delivered size as ``"WxH"``.

    Any failure while loading the model or generating is reported as HTTP 500
    with ``{"error": message}``.
    """
    # Derive the prompt variations from the request seed, so the same seed
    # reproduces the same prompt and therefore the same image.
    final_prompt = _build_prompt(request.style, request.color, random.Random(request.seed))
    target_width, target_height = _RESOLUTIONS.get(request.resolution, _DEFAULT_RESOLUTION)
    gen_width, gen_height = _generation_size(target_width, target_height)

    # NOTE(review): inference runs synchronously inside this async handler, so
    # it blocks the event loop (including /health) until it finishes. It also
    # serialises access to the shared pipeline; moving the work to a thread
    # pool would need a lock around the pipeline call.
    try:
        p = load_pipeline()  # lazy load on first request
        # A dedicated CPU generator gives device-independent reproducibility
        # for a seed, without touching torch's global RNG.
        generator = torch.Generator("cpu").manual_seed(request.seed)
        with torch.no_grad():
            base_image = p(
                final_prompt,
                num_inference_steps=request.steps,
                guidance_scale=1.0,
                # NOTE(review): diffusers only applies negative_prompt when
                # classifier-free guidance is active (guidance_scale > 1), so
                # it likely has no effect here. Confirm before relying on it.
                negative_prompt=_NEGATIVE_PROMPT,
                height=gen_height,
                width=gen_width,
                generator=generator,
            ).images[0]
        resized_image = resize_image(base_image, target_width, target_height)
        # Cluster the small rendered image: Lanczos upscaling leaves the
        # colour distribution essentially unchanged but multiplies the
        # k-means cost (about 8M pixels at 4K).
        colors = extract_colors(base_image)
        return JSONResponse(content={
            "image": _to_png_data_url(resized_image),
            "palette": colors,
            "seed": request.seed,
            "style": request.style,
            "steps": request.steps,
            "resolution": f"{target_width}x{target_height}",
        })
    except Exception as e:
        # Top-level request boundary: report the failure to the client.
        print(f"Error during generation: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/random-seed/")
async def random_seed_endpoint():
    """Return a fresh seed in the inclusive range 1..999999 for the client."""
    fresh_seed = random.randint(1, 999999)
    return dict(seed=fresh_seed)
@app.get("/health")
async def health_check():
    """Liveness probe: report that the API process is up.

    This does not touch or load the model.
    """
    payload = dict(status="healthy", message="AI Wallpaper Generator API is running")
    return payload
# Serve the built frontend at "/". It is mounted after the API routes so those
# routes are matched first. The directory can be overridden with the
# FRONTEND_BUILD_DIR environment variable; the default is resolved against
# the current working directory.
FRONTEND_BUILD_DIR = os.getenv("FRONTEND_BUILD_DIR", "./frontend/build")
app.mount("/", StaticFiles(directory=FRONTEND_BUILD_DIR, html=True), name="static")
if __name__ == "__main__":
    # Listen on $PORT when set, otherwise 7860, on all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    print("Starting AI Wallpaper Generator API...")
    uvicorn.run(app, port=listen_port, host="0.0.0.0")