# Flask text-to-image API backed by an LCM diffusion pipeline.
# (Header reconstructed: the original lines were build-page extraction residue.)
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import os
import json
import secrets
from io import BytesIO
import gc
from datetime import datetime
import traceback

app = Flask(__name__)
CORS(app)

# --- Configuration ---------------------------------------------------------
BASE = "/home/sd"                   # data directory for key/usage bookkeeping
WL_PATH = f"{BASE}/whitelist.txt"   # one API key per line
USAGE_PATH = f"{BASE}/usage.json"   # {key: {"YYYY-MM": count}}
LIMITS_PATH = f"{BASE}/limits.json" # {key: int or "unlimited"}
DEFAULT_LIMIT = 500                 # default monthly quota per key

# LCM checkpoint: fast (4-8 steps) with good quality.
# Alternatives: "segmind/SSD-1B" (smaller) or "stabilityai/sdxl-turbo" (fastest).
MODEL_ID = "Lykon/dreamshaper-8-lcm"

# Module-global pipeline, loaded lazily by init_pipeline().
pipe = None
def init_pipeline():
    """Load and cache the diffusion pipeline (module-global singleton).

    Returns the cached pipeline immediately on repeat calls. Tries MODEL_ID
    first; on any failure falls back to a known-good LCM checkpoint.

    Raises:
        RuntimeError: if neither the primary nor the fallback model loads
            (chained from the fallback's original exception).
    """
    global pipe
    if pipe is not None:
        return pipe

    print(f"Loading model: {MODEL_ID}")
    use_cuda = torch.cuda.is_available()
    # Half precision halves VRAM use and speeds inference on GPU.
    torch_dtype = torch.float16 if use_cuda else torch.float32

    try:
        pipe = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            variant="fp16" if torch_dtype == torch.float16 else None,
            use_safetensors=True,
            safety_checker=None,  # disabled for speed (optional)
            requires_safety_checker=False,
        )
        device = "cuda" if use_cuda else "cpu"
        if use_cuda and torch_dtype == torch.float16:
            # enable_model_cpu_offload() manages device placement itself and
            # must NOT be combined with an explicit .to("cuda") — the original
            # did both, which breaks diffusers' offload hooks.
            pipe.enable_model_cpu_offload()
            pipe.enable_attention_slicing()  # reduce peak memory
        else:
            pipe = pipe.to(device)
            if use_cuda:
                pipe.enable_attention_slicing()  # reduce peak memory
        print(f"Model loaded successfully on {device}")
        return pipe
    except Exception as primary_err:
        print(f"Error loading model: {primary_err}")
        # Fallback to a smaller, known-good LCM model.
        try:
            pipe = DiffusionPipeline.from_pretrained(
                "SimianLuo/LCM_Dreamshaper_v7",
                torch_dtype=torch_dtype,
            ).to("cuda" if use_cuda else "cpu")
            print("Loaded fallback model")
            return pipe
        except Exception as fallback_err:
            # Was a bare `except:` raising a context-free Exception, which
            # also swallowed KeyboardInterrupt/SystemExit. Chain the cause.
            raise RuntimeError("Failed to load any model") from fallback_err
# Bootstrap the data directory and bookkeeping files on first run.
os.makedirs(BASE, exist_ok=True)
for path in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
    if os.path.exists(path):
        continue
    # JSON stores start as an empty object; the whitelist starts empty.
    initial = "{}" if path.endswith(".json") else ""
    with open(path, "w") as f:
        f.write(initial)
| # Helper functions | |
def get_whitelist():
    """Return the set of whitelisted API keys (empty set on read failure).

    Blank lines and surrounding whitespace in the whitelist file are ignored.
    """
    try:
        with open(WL_PATH, "r") as f:
            return set(line.strip() for line in f if line.strip())
    except OSError:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        return set()
def load_json(path):
    """Load a JSON object from *path*; return {} if missing or invalid.

    Narrowed from a bare `except:` (which also swallowed KeyboardInterrupt)
    to the two expected failure modes: I/O errors and malformed JSON.
    """
    try:
        with open(path, "r") as f:
            return json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError subclasses ValueError.
        return {}
def save_json(path, data):
    """Write *data* to *path* as pretty-printed (2-space indented) JSON."""
    with open(path, "w") as out:
        out.write(json.dumps(data, indent=2))
def validate_api_key(key):
    """Check that *key* is whitelisted and under its monthly quota.

    Returns:
        (ok, message): (True, "OK") when the key may be used; otherwise
        (False, "Unauthorized") or (False, "Monthly limit reached").
    """
    if key not in get_whitelist():
        return False, "Unauthorized"

    limits = load_json(LIMITS_PATH)
    usage = load_json(USAGE_PATH)
    per_key_limit = limits.get(key, DEFAULT_LIMIT)
    if per_key_limit == "unlimited":
        return True, "OK"

    # Quota is tracked per calendar month.
    current_month = datetime.now().strftime("%Y-%m")
    consumed = usage.get(key, {}).get(current_month, 0)
    if consumed >= per_key_limit:
        return False, "Monthly limit reached"
    return True, "OK"
| # Routes | |
# NOTE(review): no handler in this file had an @app.route decorator, so none
# was ever registered with Flask; path guessed from the function name — confirm.
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: report model id and compute device."""
    return jsonify({
        "status": "online",
        "model": MODEL_ID,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }), 200
# NOTE(review): route decorator was missing (handler never registered);
# path guessed from the function name — confirm.
# SECURITY(review): this endpoint has no authentication — anyone who can
# reach the server can mint keys. Gate it behind an admin secret.
@app.route("/generate_key", methods=["POST"])
def generate_key():
    """Create a new API key, whitelist it, and record its monthly limit.

    JSON body (optional): {"unlimited": bool, "limit": int}.
    Returns the new key and its limit, or a 500 with the error message.
    """
    try:
        data = request.get_json() or {}
        unlimited = data.get("unlimited", False)
        limit = data.get("limit", DEFAULT_LIMIT)

        # secrets (not random): cryptographically strong token.
        key = "sk-" + secrets.token_hex(16)

        # Add to whitelist (one key per line).
        with open(WL_PATH, "a") as f:
            f.write(key + "\n")

        # Record the limit.
        limits = load_json(LIMITS_PATH)
        limits[key] = "unlimited" if unlimited else int(limit)
        save_json(LIMITS_PATH, limits)

        # Initialize an empty usage record for the key.
        usage = load_json(USAGE_PATH)
        if key not in usage:
            usage[key] = {}
        save_json(USAGE_PATH, usage)

        return jsonify({
            "key": key,
            "limit": limits[key],
            "message": "Key generated successfully"
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# NOTE(review): route decorator was missing (handler never registered);
# path guessed from the function name — confirm.
@app.route("/generate", methods=["POST"])
def generate():
    """Generate one PNG image from a text prompt.

    Auth via the `x-api-key` header; one successful generation counts one
    unit against the key's monthly quota. JSON body: {"prompt": str,
    "steps": int, "guidance": float, "width": int, "height": int}.

    Returns: image/png on success; 400 bad input, 401 unknown key,
    429 quota exhausted, 507 GPU OOM, 500 anything else.
    """
    try:
        # --- Auth / quota -------------------------------------------------
        key = request.headers.get("x-api-key", "")
        valid, message = validate_api_key(key)
        if not valid:
            # 401 for unknown keys, 429 for exhausted quota.
            return jsonify({"error": message}), 401 if message == "Unauthorized" else 429

        # --- Input validation --------------------------------------------
        data = request.get_json() or {}
        prompt = data.get("prompt", "").strip()
        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400

        # Clamp parameters to safe ranges. Malformed values now yield a 400
        # instead of falling through to the generic 500 handler.
        try:
            steps = min(max(int(data.get("steps", 4)), 1), 20)  # LCM: 4-8 steps
            # LCM uses low guidance; clamp it too (the original left it unbounded).
            guidance = min(max(float(data.get("guidance", 1.2)), 0.0), 20.0)
            width = min(max(int(data.get("width", 512)), 256), 1024)
            height = min(max(int(data.get("height", 512)), 256), 1024)
        except (TypeError, ValueError):
            return jsonify({"error": "steps/guidance/width/height must be numeric"}), 400

        # --- Generation ---------------------------------------------------
        if pipe is None:
            init_pipeline()

        print(f"Generating: {prompt[:50]}... (steps: {steps}, guidance: {guidance})")
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                width=width,
                height=height,
                output_type="pil"
            ).images[0]

        # Record usage only after a successful generation.
        usage = load_json(USAGE_PATH)
        month = datetime.now().strftime("%Y-%m")
        usage.setdefault(key, {})
        usage[key][month] = usage[key].get(month, 0) + 1
        save_json(USAGE_PATH, usage)

        # --- Response -----------------------------------------------------
        buf = BytesIO()
        image.save(buf, format="PNG", optimize=True)
        buf.seek(0)
        return send_file(buf, mimetype="image/png")
    except torch.cuda.OutOfMemoryError:
        # Free what we can before reporting the failure.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return jsonify({"error": "GPU out of memory. Try smaller image size."}), 507
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Generation error: {error_details}")
        return jsonify({
            "error": "Generation failed",
            "details": str(e)
        }), 500
# NOTE(review): route decorator was missing (handler never registered);
# path guessed from the function name — confirm.
@app.route("/status", methods=["GET"])
def status():
    """Report the calling key's usage, limit, and remaining quota.

    Auth via the `x-api-key` header; the key is echoed back masked
    (first 8 + last 4 characters) when long enough to mask safely.
    """
    key = request.headers.get("x-api-key", "")
    if key not in get_whitelist():
        return jsonify({"error": "Invalid API key"}), 401

    limits = load_json(LIMITS_PATH)
    usage = load_json(USAGE_PATH)
    month = datetime.now().strftime("%Y-%m")
    used = usage.get(key, {}).get(month, 0)
    limit = limits.get(key, DEFAULT_LIMIT)

    return jsonify({
        "key": key[:8] + "..." + key[-4:] if len(key) > 12 else key,
        "usage": used,
        "limit": limit,
        "remaining": "unlimited" if limit == "unlimited" else max(0, limit - used),
        "month": month
    })
if __name__ == "__main__":
    # Warm the model before accepting traffic so the first request is fast.
    print("Initializing pipeline...")
    init_pipeline()
    print("API starting on port 7860...")
    app.run(host="0.0.0.0", port=7860, debug=False)