guydffdsdsfd committed on
Commit
277eaa3
·
verified ·
1 Parent(s): 70a06fc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +265 -0
app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS
3
+ from diffusers import DiffusionPipeline, LCMScheduler
4
+ import torch
5
+ import os
6
+ import json
7
+ import secrets
8
+ from io import BytesIO
9
+ import gc
10
+ from datetime import datetime
11
+ import traceback
12
+
13
+ app = Flask(__name__)
14
+ CORS(app)
15
+
16
+ # Configuration
17
+ BASE = "/home/sd"
18
+ WL_PATH = f"{BASE}/whitelist.txt"
19
+ USAGE_PATH = f"{BASE}/usage.json"
20
+ LIMITS_PATH = f"{BASE}/limits.json"
21
+ DEFAULT_LIMIT = 500
22
+
23
+ # Use a fast, reliable model: LCM version for speed + quality
24
+ # Alternatives: "segmind/SSD-1B" (smaller) or "stabilityai/sdxl-turbo" (fastest)
25
+ MODEL_ID = "Lykon/dreamshaper-8-lcm"
26
+
27
+ # Global pipeline with lazy loading
28
+ pipe = None
29
+
30
+ def init_pipeline():
31
+ """Initialize the pipeline with optimizations"""
32
+ global pipe
33
+
34
+ if pipe is not None:
35
+ return pipe
36
+
37
+ print(f"Loading model: {MODEL_ID}")
38
+
39
+ # Use half precision for speed and memory efficiency
40
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
41
+
42
+ try:
43
+ # Load pipeline with optimizations
44
+ pipe = DiffusionPipeline.from_pretrained(
45
+ MODEL_ID,
46
+ torch_dtype=torch_dtype,
47
+ variant="fp16" if torch_dtype == torch.float16 else None,
48
+ use_safetensors=True,
49
+ safety_checker=None, # Disable for speed (optional)
50
+ requires_safety_checker=False
51
+ )
52
+
53
+ # Move to GPU if available
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ pipe = pipe.to(device)
56
+
57
+ # Enable optimizations
58
+ if device == "cuda":
59
+ pipe.enable_attention_slicing() # Reduce memory usage
60
+ if torch_dtype == torch.float16:
61
+ pipe.enable_model_cpu_offload() # Offload to CPU when not in use
62
+
63
+ print(f"Model loaded successfully on {device}")
64
+ return pipe
65
+
66
+ except Exception as e:
67
+ print(f"Error loading model: {e}")
68
+ # Fallback to a simpler model
69
+ try:
70
+ pipe = DiffusionPipeline.from_pretrained(
71
+ "SimianLuo/LCM_Dreamshaper_v7",
72
+ torch_dtype=torch_dtype
73
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
74
+ print("Loaded fallback model")
75
+ return pipe
76
+ except:
77
+ raise Exception("Failed to load any model")
78
+
79
+ # Initialize storage
80
+ os.makedirs(BASE, exist_ok=True)
81
+ for path in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
82
+ if not os.path.exists(path):
83
+ if path.endswith(".json"):
84
+ with open(path, "w") as f:
85
+ json.dump({}, f)
86
+ else:
87
+ with open(path, "w") as f:
88
+ f.write("")
89
+
90
+ # Helper functions
91
+ def get_whitelist():
92
+ try:
93
+ with open(WL_PATH, "r") as f:
94
+ return set(line.strip() for line in f if line.strip())
95
+ except:
96
+ return set()
97
+
98
+ def load_json(path):
99
+ try:
100
+ with open(path, "r") as f:
101
+ return json.load(f)
102
+ except:
103
+ return {}
104
+
105
+ def save_json(path, data):
106
+ with open(path, "w") as f:
107
+ json.dump(data, f, indent=2)
108
+
109
+ def validate_api_key(key):
110
+ """Validate API key and check rate limits"""
111
+ if key not in get_whitelist():
112
+ return False, "Unauthorized"
113
+
114
+ limits = load_json(LIMITS_PATH)
115
+ usage = load_json(USAGE_PATH)
116
+
117
+ limit = limits.get(key, DEFAULT_LIMIT)
118
+ if limit == "unlimited":
119
+ return True, "OK"
120
+
121
+ month = datetime.now().strftime("%Y-%m")
122
+ used = usage.get(key, {}).get(month, 0)
123
+
124
+ if used >= limit:
125
+ return False, "Monthly limit reached"
126
+
127
+ return True, "OK"
128
+
129
+ # Routes
130
+ @app.route("/", methods=["GET"])
131
+ def health():
132
+ return jsonify({
133
+ "status": "online",
134
+ "model": MODEL_ID,
135
+ "device": "cuda" if torch.cuda.is_available() else "cpu"
136
+ }), 200
137
+
138
+ @app.route("/generate-key", methods=["POST"])
139
+ def generate_key():
140
+ try:
141
+ data = request.get_json() or {}
142
+ unlimited = data.get("unlimited", False)
143
+ limit = data.get("limit", DEFAULT_LIMIT)
144
+
145
+ key = "sk-" + secrets.token_hex(16)
146
+
147
+ # Add to whitelist
148
+ with open(WL_PATH, "a") as f:
149
+ f.write(key + "\n")
150
+
151
+ # Set limits
152
+ limits = load_json(LIMITS_PATH)
153
+ limits[key] = "unlimited" if unlimited else int(limit)
154
+ save_json(LIMITS_PATH, limits)
155
+
156
+ # Initialize usage
157
+ usage = load_json(USAGE_PATH)
158
+ if key not in usage:
159
+ usage[key] = {}
160
+ save_json(USAGE_PATH, usage)
161
+
162
+ return jsonify({
163
+ "key": key,
164
+ "limit": limits[key],
165
+ "message": "Key generated successfully"
166
+ })
167
+
168
+ except Exception as e:
169
+ return jsonify({"error": str(e)}), 500
170
+
171
+ @app.route("/api/generate", methods=["POST"])
172
+ def generate():
173
+ try:
174
+ # Validate API key
175
+ key = request.headers.get("x-api-key", "")
176
+ valid, message = validate_api_key(key)
177
+ if not valid:
178
+ return jsonify({"error": message}), 401 if message == "Unauthorized" else 429
179
+
180
+ # Parse request
181
+ data = request.get_json() or {}
182
+ prompt = data.get("prompt", "").strip()
183
+
184
+ if not prompt:
185
+ return jsonify({"error": "Prompt is required"}), 400
186
+
187
+ # Set generation parameters with safe defaults
188
+ steps = min(max(int(data.get("steps", 4)), 1), 20) # LCM models work with 4-8 steps
189
+ guidance = float(data.get("guidance", 1.2)) # LCM uses low guidance
190
+ width = min(max(int(data.get("width", 512)), 256), 1024)
191
+ height = min(max(int(data.get("height", 512)), 256), 1024)
192
+
193
+ # Ensure pipeline is loaded
194
+ if pipe is None:
195
+ init_pipeline()
196
+
197
+ # Generate image
198
+ print(f"Generating: {prompt[:50]}... (steps: {steps}, guidance: {guidance})")
199
+
200
+ with torch.inference_mode():
201
+ image = pipe(
202
+ prompt=prompt,
203
+ num_inference_steps=steps,
204
+ guidance_scale=guidance,
205
+ width=width,
206
+ height=height,
207
+ output_type="pil"
208
+ ).images[0]
209
+
210
+ # Update usage
211
+ usage = load_json(USAGE_PATH)
212
+ month = datetime.now().strftime("%Y-%m")
213
+ usage.setdefault(key, {})
214
+ usage[key][month] = usage[key].get(month, 0) + 1
215
+ save_json(USAGE_PATH, usage)
216
+
217
+ # Return image
218
+ buf = BytesIO()
219
+ image.save(buf, format="PNG", optimize=True)
220
+ buf.seek(0)
221
+
222
+ return send_file(buf, mimetype="image/png")
223
+
224
+ except torch.cuda.OutOfMemoryError:
225
+ gc.collect()
226
+ if torch.cuda.is_available():
227
+ torch.cuda.empty_cache()
228
+ return jsonify({"error": "GPU out of memory. Try smaller image size."}), 507
229
+
230
+ except Exception as e:
231
+ error_details = traceback.format_exc()
232
+ print(f"Generation error: {error_details}")
233
+ return jsonify({
234
+ "error": "Generation failed",
235
+ "details": str(e)
236
+ }), 500
237
+
238
+ @app.route("/api/status", methods=["GET"])
239
+ def status():
240
+ """Check API key status and usage"""
241
+ key = request.headers.get("x-api-key", "")
242
+ if key not in get_whitelist():
243
+ return jsonify({"error": "Invalid API key"}), 401
244
+
245
+ limits = load_json(LIMITS_PATH)
246
+ usage = load_json(USAGE_PATH)
247
+
248
+ month = datetime.now().strftime("%Y-%m")
249
+ used = usage.get(key, {}).get(month, 0)
250
+ limit = limits.get(key, DEFAULT_LIMIT)
251
+
252
+ return jsonify({
253
+ "key": key[:8] + "..." + key[-4:] if len(key) > 12 else key,
254
+ "usage": used,
255
+ "limit": limit,
256
+ "remaining": "unlimited" if limit == "unlimited" else max(0, limit - used),
257
+ "month": month
258
+ })
259
+
260
+ if __name__ == "__main__":
261
+ # Initialize pipeline on startup
262
+ print("Initializing pipeline...")
263
+ init_pipeline()
264
+ print("API starting on port 7860...")
265
+ app.run(host="0.0.0.0", port=7860, debug=False)