image / app.py
guydffdsdsfd's picture
Create app.py
277eaa3 verified
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import os
import json
import secrets
from io import BytesIO
import gc
from datetime import datetime
import traceback
# Flask application object. CORS(app) is called with no options, so whatever
# flask_cors applies by default is in effect for every route.
app = Flask(__name__)
CORS(app)

# --- Storage layout ---------------------------------------------------------
# Every file the service persists lives directly under BASE.
BASE = "/home/sd"
WL_PATH = BASE + "/whitelist.txt"    # one API key per line
USAGE_PATH = BASE + "/usage.json"    # {key: {"YYYY-MM": request count}}
LIMITS_PATH = BASE + "/limits.json"  # {key: int or "unlimited"}

# Monthly quota applied to a key that has no entry in the limits file.
DEFAULT_LIMIT = 500

# --- Model ------------------------------------------------------------------
# Primary checkpoint: an LCM variant, chosen for speed + quality.
# Alternatives: "segmind/SSD-1B" (smaller) or "stabilityai/sdxl-turbo" (fastest)
MODEL_ID = "Lykon/dreamshaper-8-lcm"

# Lazily-initialised pipeline singleton; stays None until init_pipeline() runs.
pipe = None
def init_pipeline():
    """Load the diffusion pipeline once and cache it in the global ``pipe``.

    The primary checkpoint (``MODEL_ID``) is tried first; if anything in its
    load/placement/optimisation sequence fails, a fallback checkpoint is
    tried. The global is only assigned once a pipeline is fully set up, so a
    half-initialised object can never be handed out by a later call.

    Returns:
        The ready-to-use pipeline (the same object stored in ``pipe``).

    Raises:
        Exception: "Failed to load any model" when the fallback also fails;
            the fallback's own error is attached as ``__cause__``.
    """
    global pipe
    if pipe is not None:
        return pipe

    print(f"Loading model: {MODEL_ID}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision on GPU for speed and memory; full precision on CPU.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    try:
        loaded = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            variant="fp16" if torch_dtype == torch.float16 else None,
            use_safetensors=True,
            safety_checker=None,  # Disable for speed (optional)
            requires_safety_checker=False,
        )
        loaded = loaded.to(device)

        if device == "cuda":
            loaded.enable_attention_slicing()  # Reduce memory usage
            if torch_dtype == torch.float16:
                # NOTE(review): the pipeline was already moved to CUDA above;
                # diffusers' offload docs suggest not doing both. Call order
                # is kept as in the original - confirm which one is intended.
                loaded.enable_model_cpu_offload()
    except Exception as e:
        # Deliberately broad: any failure of the primary model triggers the
        # fallback. KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"Error loading model: {e}")
    else:
        pipe = loaded
        print(f"Model loaded successfully on {device}")
        return pipe

    # Fallback to a simpler model
    try:
        loaded = DiffusionPipeline.from_pretrained(
            "SimianLuo/LCM_Dreamshaper_v7",
            torch_dtype=torch_dtype,
        ).to(device)
    except Exception as fallback_error:
        raise Exception("Failed to load any model") from fallback_error

    pipe = loaded
    print("Loaded fallback model")
    return pipe
# Bootstrap on-disk storage: create BASE if needed, then seed every store
# that does not exist yet. JSON stores start as an empty object; the
# whitelist starts as an empty file.
os.makedirs(BASE, exist_ok=True)
for path in (WL_PATH, USAGE_PATH, LIMITS_PATH):
    if os.path.exists(path):
        continue  # never overwrite existing data
    with open(path, "w") as f:
        if path.endswith(".json"):
            json.dump({}, f)
        else:
            f.write("")
# Helper functions
def get_whitelist():
    """Return the set of API keys stored in the whitelist file.

    Each non-blank line of ``WL_PATH`` is one key; surrounding whitespace is
    stripped and blank lines are ignored.

    Returns:
        set[str]: the whitelisted keys, or an empty set when the file cannot
        be read or decoded (best-effort: an unreadable whitelist authorises
        nobody rather than raising).
    """
    try:
        with open(WL_PATH, "r") as f:
            stripped = (line.strip() for line in f)
            return {entry for entry in stripped if entry}
    except (OSError, UnicodeDecodeError):
        # Only I/O and decoding problems are treated as "no whitelist";
        # anything else (e.g. KeyboardInterrupt) is allowed to propagate.
        return set()
def load_json(path):
    """Read a JSON object from *path*, falling back to an empty dict.

    Args:
        path: location of the JSON file.

    Returns:
        dict: the decoded object. ``{}`` is returned when the file is missing
        or unreadable, when its content is not valid JSON, or when the
        top-level value is not an object (callers index the result with
        ``.get``, so any other type would crash them).
    """
    try:
        with open(path, "r") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}
def save_json(path, data):
    """Write *data* to *path* as indented JSON, replacing the file atomically.

    The payload is serialised before anything on disk is touched, then
    written to a sibling temporary file that is moved over *path* with
    ``os.replace``. A serialisation error or a crash mid-write therefore
    leaves the previous file intact instead of truncated, and a concurrent
    reader never observes a half-written file.

    Args:
        path: destination file.
        data: JSON-serialisable object.

    Raises:
        TypeError / ValueError: if *data* cannot be serialised.
        OSError: if the temporary file cannot be written or moved.
    """
    import threading

    payload = json.dumps(data, indent=2)

    # Same directory as the target so os.replace stays on one filesystem;
    # pid + thread id keep concurrent writers from sharing a temp name.
    tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            f.write(payload)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp name is already gone; on
        # failure this removes the leftover.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def validate_api_key(key):
    """Decide whether *key* is allowed to make a request right now.

    Returns:
        tuple[bool, str]: ``(True, "OK")`` when the key is whitelisted and
        either unlimited or still under its quota for the current month;
        ``(False, "Unauthorized")`` for an unknown key;
        ``(False, "Monthly limit reached")`` once the quota is used up.
    """
    whitelist = get_whitelist()
    if key not in whitelist:
        return False, "Unauthorized"

    per_key_limits = load_json(LIMITS_PATH)
    usage_by_key = load_json(USAGE_PATH)

    # Keys without an explicit entry get the default quota.
    quota = per_key_limits.get(key, DEFAULT_LIMIT)
    if quota == "unlimited":
        return True, "OK"

    # Usage is bucketed per calendar month ("YYYY-MM", server local time).
    current_month = datetime.now().strftime("%Y-%m")
    monthly_counts = usage_by_key.get(key, {})
    spent = monthly_counts.get(current_month, 0)

    exhausted = spent >= quota
    return (False, "Monthly limit reached") if exhausted else (True, "OK")
# Routes
@app.route("/", methods=["GET"])
def health():
    """Health endpoint: report the configured model id and compute device."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    payload = {
        "status": "online",
        "model": MODEL_ID,
        "device": device,
    }
    return jsonify(payload), 200
@app.route("/generate-key", methods=["POST"])
def generate_key():
    """Create a new API key and register its monthly limit.

    Optional JSON body fields:
        unlimited: truthy to create a key without a quota.
        limit: integer monthly quota (defaults to ``DEFAULT_LIMIT``); ignored
            when ``unlimited`` is truthy.

    Responses:
        200: ``{"key", "limit", "message"}``.
        400: the body is JSON but not an object, or ``limit`` is not an
            integer. Nothing is persisted in that case.
        500: ``{"error": ...}`` for any unexpected failure.
    """
    # NOTE(review): no authentication is visible on this route in this file,
    # so any caller can mint keys (including unlimited ones). Confirm it is
    # protected upstream.
    try:
        # silent=True: a missing or non-JSON body means "use the defaults"
        # instead of surfacing as a 500.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            return jsonify({"error": "JSON body must be an object"}), 400

        # Validate the limit before anything is written, so a bad request
        # cannot leave an orphan key behind in the whitelist.
        if data.get("unlimited", False):
            key_limit = "unlimited"
        else:
            try:
                key_limit = int(data.get("limit", DEFAULT_LIMIT))
            except (TypeError, ValueError, OverflowError):
                return jsonify({"error": "limit must be an integer"}), 400

        key = "sk-" + secrets.token_hex(16)

        # Set limits
        limits = load_json(LIMITS_PATH)
        limits[key] = key_limit
        save_json(LIMITS_PATH, limits)

        # Initialize usage
        usage = load_json(USAGE_PATH)
        if key not in usage:
            usage[key] = {}
        save_json(USAGE_PATH, usage)

        # Whitelist last: the key only becomes usable once its limit is
        # stored, so a failure above can never activate a key that silently
        # falls back to DEFAULT_LIMIT.
        with open(WL_PATH, "a") as f:
            f.write(key + "\n")

        return jsonify({
            "key": key,
            "limit": key_limit,
            "message": "Key generated successfully"
        })
    except Exception as e:
        # Route boundary: log the traceback, then report the failure.
        print(f"Key generation error: {traceback.format_exc()}")
        return jsonify({"error": str(e)}), 500
def _parse_generation_params(data):
    """Extract and clamp the generation settings from the request payload.

    Returns:
        tuple: ``(steps, guidance, width, height)`` where steps is clamped to
        1..20 and width/height to 256..1024, rounded down to a multiple of 8.

    Raises:
        TypeError, ValueError, OverflowError: if a field is not numeric, or
            guidance is NaN/infinite.
    """
    import math

    steps = min(max(int(data.get("steps", 4)), 1), 20)  # LCM models work with 4-8 steps
    guidance = float(data.get("guidance", 1.2))  # LCM uses low guidance
    if not math.isfinite(guidance):
        raise ValueError("guidance must be a finite number")

    width = min(max(int(data.get("width", 512)), 256), 1024)
    height = min(max(int(data.get("height", 512)), 256), 1024)
    # Stable Diffusion pipelines reject sizes that are not divisible by 8;
    # snap down (both clamp bounds are multiples of 8, so the range holds).
    width -= width % 8
    height -= height % 8
    return steps, guidance, width, height


@app.route("/api/generate", methods=["POST"])
def generate():
    """Generate one image from a text prompt and return it as a PNG.

    Request: header ``x-api-key``; JSON body with ``prompt`` (required) and
    optional ``steps``, ``guidance``, ``width``, ``height``.

    Responses:
        200: ``image/png`` body; the key's usage for the current month is
            incremented by one.
        400: missing/invalid prompt, non-object body, or non-numeric settings.
        401: key is not whitelisted.
        429: monthly limit reached.
        507: GPU ran out of memory.
        500: any other failure (``{"error", "details"}``).
    """
    try:
        # Validate API key
        key = request.headers.get("x-api-key", "")
        valid, message = validate_api_key(key)
        if not valid:
            return jsonify({"error": message}), 401 if message == "Unauthorized" else 429

        # Parse request; silent=True so a non-JSON body is reported as a
        # missing prompt (400) rather than an internal error.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            return jsonify({"error": "JSON body must be an object"}), 400

        prompt = data.get("prompt", "")
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({"error": "Prompt is required"}), 400
        prompt = prompt.strip()

        # Bad client input is a 400, not a 500.
        try:
            steps, guidance, width, height = _parse_generation_params(data)
        except (TypeError, ValueError, OverflowError):
            return jsonify({
                "error": "steps, guidance, width and height must be numbers"
            }), 400

        # Ensure pipeline is loaded (init_pipeline returns the cached object)
        pipeline = pipe if pipe is not None else init_pipeline()

        # Generate image
        print(f"Generating: {prompt[:50]}... (steps: {steps}, guidance: {guidance})")
        with torch.inference_mode():
            image = pipeline(
                prompt=prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                width=width,
                height=height,
                output_type="pil"
            ).images[0]

        # Update usage only after a successful generation.
        # NOTE(review): this read-modify-write is not synchronised; confirm
        # whether concurrent requests for the same store are expected.
        usage = load_json(USAGE_PATH)
        month = datetime.now().strftime("%Y-%m")
        usage.setdefault(key, {})
        usage[key][month] = usage[key].get(month, 0) + 1
        save_json(USAGE_PATH, usage)

        # Return image
        buf = BytesIO()
        image.save(buf, format="PNG", optimize=True)
        buf.seek(0)
        return send_file(buf, mimetype="image/png")

    except torch.cuda.OutOfMemoryError:
        # Release what we can so the next (smaller) request has a chance.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return jsonify({"error": "GPU out of memory. Try smaller image size."}), 507

    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Generation error: {error_details}")
        return jsonify({
            "error": "Generation failed",
            "details": str(e)
        }), 500
@app.route("/api/status", methods=["GET"])
def status():
    """Report quota and current-month usage for the key in ``x-api-key``.

    Returns 401 with ``{"error": "Invalid API key"}`` when the key is not
    whitelisted; otherwise a JSON object with the masked key, usage, limit,
    remaining quota and the month the figures refer to.
    """
    api_key = request.headers.get("x-api-key", "")
    if api_key not in get_whitelist():
        return jsonify({"error": "Invalid API key"}), 401

    per_key_limits = load_json(LIMITS_PATH)
    usage_by_key = load_json(USAGE_PATH)

    current_month = datetime.now().strftime("%Y-%m")
    spent = usage_by_key.get(api_key, {}).get(current_month, 0)
    quota = per_key_limits.get(api_key, DEFAULT_LIMIT)

    # Only echo the first 8 and last 4 characters of longer keys.
    if len(api_key) > 12:
        masked_key = api_key[:8] + "..." + api_key[-4:]
    else:
        masked_key = api_key

    if quota == "unlimited":
        remaining = "unlimited"
    else:
        remaining = max(0, quota - spent)  # never report a negative balance

    return jsonify({
        "key": masked_key,
        "usage": spent,
        "limit": quota,
        "remaining": remaining,
        "month": current_month
    })
if __name__ == "__main__":
    # Load the model before the server starts, so the first request does not
    # have to wait for init_pipeline().
    print("Initializing pipeline...")
    init_pipeline()

    # Listen on all interfaces, port 7860, with Flask's debug mode off.
    print("API starting on port 7860...")
    app.run(debug=False, port=7860, host="0.0.0.0")