from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import os
import json
import secrets
from io import BytesIO
import gc
from datetime import datetime
import traceback

app = Flask(__name__)
CORS(app)

# Configuration
BASE = "/home/sd"
WL_PATH = f"{BASE}/whitelist.txt"
USAGE_PATH = f"{BASE}/usage.json"
LIMITS_PATH = f"{BASE}/limits.json"
DEFAULT_LIMIT = 500  # images per key per month unless overridden per key

# Use a fast, reliable model: LCM version for speed + quality
# Alternatives: "segmind/SSD-1B" (smaller) or "stabilityai/sdxl-turbo" (fastest)
MODEL_ID = "Lykon/dreamshaper-8-lcm"

# Global pipeline with lazy loading
pipe = None


def init_pipeline():
    """Load the diffusion pipeline once, caching it in the module-global `pipe`.

    Uses fp16 on CUDA (with the "fp16" weight variant) for speed and memory,
    fp32 on CPU. If the primary model fails to load, falls back to
    "SimianLuo/LCM_Dreamshaper_v7".

    Returns:
        The loaded ``DiffusionPipeline`` (also stored in the global ``pipe``).

    Raises:
        RuntimeError: if neither the primary nor the fallback model loads;
            chained from the fallback's original exception.
    """
    global pipe
    if pipe is not None:
        return pipe

    print(f"Loading model: {MODEL_ID}")

    # Use half precision for speed and memory efficiency
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    try:
        # Load pipeline with optimizations
        pipe = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            variant="fp16" if torch_dtype == torch.float16 else None,
            use_safetensors=True,
            safety_checker=None,  # Disable for speed (optional)
            requires_safety_checker=False
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"

        if device == "cuda" and torch_dtype == torch.float16:
            # enable_model_cpu_offload() manages device placement itself.
            # BUG FIX: the previous code called pipe.to("cuda") first and
            # *then* enabled offloading, which pins the whole pipeline in
            # VRAM and defeats the offload (per diffusers docs, do not
            # move the pipeline to CUDA when using model CPU offload).
            pipe.enable_model_cpu_offload()
        else:
            pipe = pipe.to(device)

        if device == "cuda":
            pipe.enable_attention_slicing()  # Reduce peak memory usage

        print(f"Model loaded successfully on {device}")
        return pipe

    except Exception as e:
        print(f"Error loading model: {e}")
        # Fallback to a simpler model
        try:
            pipe = DiffusionPipeline.from_pretrained(
                "SimianLuo/LCM_Dreamshaper_v7",
                torch_dtype=torch_dtype
            ).to("cuda" if torch.cuda.is_available() else "cpu")
            print("Loaded fallback model")
            return pipe
        except Exception as fallback_error:
            # Was a bare `except:`; chain the cause so the real failure
            # is visible in the traceback. RuntimeError is still caught
            # by any caller catching Exception.
            raise RuntimeError("Failed to load any model") from fallback_error


# Initialize storage directory
os.makedirs(BASE, exist_ok=True)
# Ensure the storage files exist so later reads never fail on a missing file.
for path in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
    if not os.path.exists(path):
        with open(path, "w") as f:
            if path.endswith(".json"):
                json.dump({}, f)
            else:
                f.write("")


# Helper functions
def get_whitelist():
    """Return the set of whitelisted API keys.

    Blank lines are ignored. An unreadable file yields an empty set so the
    API degrades to "reject everyone" rather than crashing.
    """
    try:
        with open(WL_PATH, "r") as f:
            return set(line.strip() for line in f if line.strip())
    except OSError:  # was a bare except; only file errors are expected here
        return set()


def load_json(path):
    """Load a JSON object from *path*, returning {} if missing or corrupt."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
        return {}


def save_json(path, data):
    """Write *data* to *path* as pretty-printed JSON (overwrites the file)."""
    with open(path, "w") as f:
        json.dump(data, f, indent=2)


def validate_api_key(key):
    """Validate an API key against the whitelist and its monthly quota.

    Returns:
        (True, "OK") when the key may generate, otherwise (False, reason)
        where reason is "Unauthorized" or "Monthly limit reached".
    """
    if key not in get_whitelist():
        return False, "Unauthorized"

    limits = load_json(LIMITS_PATH)
    usage = load_json(USAGE_PATH)

    limit = limits.get(key, DEFAULT_LIMIT)
    if limit == "unlimited":
        return True, "OK"

    # Usage is tracked in per-calendar-month (YYYY-MM) buckets.
    month = datetime.now().strftime("%Y-%m")
    used = usage.get(key, {}).get(month, 0)
    if used >= limit:
        return False, "Monthly limit reached"
    return True, "OK"


# Routes
@app.route("/", methods=["GET"])
def health():
    """Liveness probe: report model id and compute device."""
    return jsonify({
        "status": "online",
        "model": MODEL_ID,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }), 200


@app.route("/generate-key", methods=["POST"])
def generate_key():
    """Mint a new API key, whitelist it, and record its monthly limit.

    NOTE(review): this endpoint is unauthenticated — anyone who can reach
    the server can mint keys (including unlimited ones). Protect it with an
    admin secret or a network ACL before exposing it publicly.
    """
    try:
        data = request.get_json() or {}
        unlimited = data.get("unlimited", False)
        limit = data.get("limit", DEFAULT_LIMIT)

        key = "sk-" + secrets.token_hex(16)

        # Add to whitelist
        with open(WL_PATH, "a") as f:
            f.write(key + "\n")

        # Record the key's limit ("unlimited" sentinel or an integer)
        limits = load_json(LIMITS_PATH)
        limits[key] = "unlimited" if unlimited else int(limit)
        save_json(LIMITS_PATH, limits)

        # Initialize usage
        usage = load_json(USAGE_PATH)
        if key not in usage:
            usage[key] = {}
        save_json(USAGE_PATH, usage)

        return jsonify({
            "key": key,
            "limit": limits[key],
            "message": "Key generated successfully"
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/generate", methods=["POST"])
def generate():
    """Generate one PNG image from a text prompt.

    Expects an ``x-api-key`` header and a JSON body with ``prompt``
    (required) plus optional ``steps``, ``guidance``, ``width``, ``height``.
    Returns the image as ``image/png``, or a JSON error with 400/401/429/
    500/507 as appropriate.
    """
    try:
        # Validate API key
        key = request.headers.get("x-api-key", "")
        valid, message = validate_api_key(key)
        if not valid:
            return jsonify({"error": message}), 401 if message == "Unauthorized" else 429

        # Parse request
        data = request.get_json() or {}
        prompt = data.get("prompt", "").strip()
        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400

        # Clamp generation parameters to safe ranges. BUG FIX: non-numeric
        # values previously made int()/float() raise and surfaced as a 500;
        # reject them with a 400 instead.
        try:
            steps = min(max(int(data.get("steps", 4)), 1), 20)  # LCM models work with 4-8 steps
            guidance = float(data.get("guidance", 1.2))  # LCM uses low guidance
            width = min(max(int(data.get("width", 512)), 256), 1024)
            height = min(max(int(data.get("height", 512)), 256), 1024)
        except (TypeError, ValueError):
            return jsonify({"error": "steps, guidance, width and height must be numbers"}), 400

        # Stable-Diffusion-family UNets require dimensions divisible by 8;
        # round down so arbitrary client values cannot crash the pipeline.
        width -= width % 8
        height -= height % 8

        # Ensure pipeline is loaded (lazy init on first request)
        if pipe is None:
            init_pipeline()

        # Generate image
        print(f"Generating: {prompt[:50]}... (steps: {steps}, guidance: {guidance})")
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                width=width,
                height=height,
                output_type="pil"
            ).images[0]

        # Update usage.
        # NOTE(review): this read-modify-write of the usage file is not
        # atomic; concurrent requests can lose increments. Acceptable for a
        # soft quota, but use a lock or a real datastore for hard limits.
        usage = load_json(USAGE_PATH)
        month = datetime.now().strftime("%Y-%m")
        usage.setdefault(key, {})
        usage[key][month] = usage[key].get(month, 0) + 1
        save_json(USAGE_PATH, usage)

        # Return the image as an in-memory PNG
        buf = BytesIO()
        image.save(buf, format="PNG", optimize=True)
        buf.seek(0)
        return send_file(buf, mimetype="image/png")

    except torch.cuda.OutOfMemoryError:
        # Free what we can and ask the client to retry with a smaller size.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return jsonify({"error": "GPU out of memory. Try smaller image size."}), 507
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Generation error: {error_details}")
        return jsonify({
            "error": "Generation failed",
            "details": str(e)
        }), 500


@app.route("/api/status", methods=["GET"])
def status():
    """Check API key status and usage for the current month."""
    key = request.headers.get("x-api-key", "")
    if key not in get_whitelist():
        return jsonify({"error": "Invalid API key"}), 401

    limits = load_json(LIMITS_PATH)
    usage = load_json(USAGE_PATH)
    month = datetime.now().strftime("%Y-%m")
    used = usage.get(key, {}).get(month, 0)
    limit = limits.get(key, DEFAULT_LIMIT)

    return jsonify({
        # Mask the key in the response; the parentheses make the
        # conditional's scope explicit (short keys are echoed unmasked).
        "key": (key[:8] + "..." + key[-4:]) if len(key) > 12 else key,
        "usage": used,
        "limit": limit,
        "remaining": "unlimited" if limit == "unlimited" else max(0, limit - used),
        "month": month
    })


if __name__ == "__main__":
    # Initialize pipeline on startup so the first request is not slow
    print("Initializing pipeline...")
    init_pipeline()
    print("API starting on port 7860...")
    app.run(host="0.0.0.0", port=7860, debug=False)