# Hugging Face Spaces deployment-log header captured with this paste:
# the Space reported "Runtime error" twice (likely the missing route
# decorators below — no endpoint was ever registered).
import io
import os
import threading
import time
from queue import Queue
from urllib.parse import urlparse

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
app = Flask(__name__)

# Allow browser calls from any origin; auth is enforced per-request via
# x-api-key / Origin checks, not via CORS.
_cors_policy = {
    "origins": ["*"],
    "methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["Content-Type", "Authorization", "x-api-key"],
}
CORS(app, resources={r"/*": _cors_policy})
# --- configuration and shared mutable state ---------------------------------
WL_PATH = "whitelist.txt"        # one API key per line
UNLIMITED_KEY = "sk-ess4l0ri37"  # NOTE(review): secret hard-coded in source — move to an env var
TRUSTED_DOMAINS = ["kaigpt.vercel.app", "localhost"]

image_progress = {}              # request_id -> {"progress", "status", "timestamp"}
progress_lock = threading.Lock() # guards every access to image_progress
task_queue = Queue()             # generation jobs consumed by the worker thread
print("Loading Stable Diffusion...")
pipe = None
try:
    # CPU-only pipeline; the safety checker is disabled on purpose.
    _sd = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe = _sd.to("cpu")
    print("Model loaded")
except Exception as e:
    # Keep the server up even without a model; /health reports "offline".
    print("Model failed:", e)
    pipe = None
def get_whitelist():
    """Return the set of accepted API keys.

    Reads WL_PATH (one key per line, blanks ignored); if the file does not
    exist yet, only the built-in UNLIMITED_KEY is accepted.
    """
    if not os.path.exists(WL_PATH):
        return {UNLIMITED_KEY}
    with open(WL_PATH) as fh:
        return {line.strip() for line in fh if line.strip()}
def is_trusted_origin():
    """Return True when the request's Origin or Referer is a trusted domain.

    Security fix: the previous substring test (``d in origin``) accepted any
    URL merely *containing* a trusted domain, e.g. ``https://localhost.evil.com``
    or ``https://kaigpt.vercel.app.attacker.com``.  We now parse the URL and
    require the hostname to equal a trusted domain exactly or be one of its
    subdomains (dot-suffix match).
    """
    for header in ("Origin", "Referer"):
        value = request.headers.get(header, "")
        if not value:
            continue
        host = (urlparse(value).hostname or "").lower()
        for domain in TRUSTED_DOMAINS:
            if host == domain or host.endswith("." + domain):
                return True
    return False
def update_progress(request_id, progress, status):
    """Thread-safely record the latest progress snapshot for a request.

    Stores percent complete, a human-readable status string, and the wall-clock
    time of the update (used later by cleanup_old_progress to expire entries).
    """
    snapshot = {
        "progress": progress,
        "status": status,
        "timestamp": time.time(),
    }
    with progress_lock:
        image_progress[request_id] = snapshot
def cleanup_old_progress():
    """Drop progress entries older than five minutes (thread-safe)."""
    cutoff = time.time() - 300
    with progress_lock:
        # Collect first, then delete — never mutate a dict while iterating it.
        stale = [rid for rid, entry in image_progress.items()
                 if entry["timestamp"] < cutoff]
        for rid in stale:
            del image_progress[rid]
def worker():
    """Background loop: consume generation jobs from task_queue forever.

    Each job dict carries request_id/prompt/steps in and gets either
    ``result`` (a PNG BytesIO) or ``error`` (a message string) written back;
    the request handler polls those two keys.  A ``None`` job is the
    shutdown sentinel.
    """
    while True:
        job = task_queue.get()
        if job is None:  # poison pill — stop the worker
            break
        rid = job["request_id"]
        prompt, steps = job["prompt"], job["steps"]
        update_progress(rid, 0, "Queued")
        try:
            def report(step, timestep, latents):
                # Per-step hook from the diffusion loop -> progress table.
                update_progress(rid, int((step / steps) * 100), f"Step {step}/{steps}")

            with torch.no_grad():
                output = pipe(
                    prompt,
                    num_inference_steps=steps,
                    guidance_scale=7.5,
                    callback=report,
                    callback_steps=1,
                )
            buf = io.BytesIO()
            output.images[0].save(buf, "PNG")
            buf.seek(0)
            update_progress(rid, 100, "Complete")
            job["result"] = buf
        except Exception as exc:
            job["error"] = str(exc)
            update_progress(rid, 0, "Error")
        task_queue.task_done()
threading.Thread(target=worker, daemon=True).start()
@app.route("/txt2img", methods=["GET", "POST", "OPTIONS"])
# NOTE(review): no route decorator existed anywhere in this file, so this
# endpoint was never registered (likely lost in the paste) — path guessed,
# confirm against the frontend.
def txt2img():
    """Queue a text-to-image job and block until it finishes.

    Auth: trusted Origin/Referer bypasses the key check; otherwise an API key
    from the x-api-key header or JSON body must be whitelisted.
    Returns the PNG on success, a JSON error with 401/503/500/504 otherwise.
    """
    if request.method == "OPTIONS":
        return jsonify({"status": "ok"})  # CORS preflight
    if not is_trusted_origin():
        # silent=True: a missing/non-JSON body must not raise before the
        # 401 — request.json would throw here on form or empty bodies.
        body = request.get_json(silent=True) or {}
        api_key = request.headers.get("x-api-key") or body.get("api_key", "")
        if api_key not in get_whitelist():
            return jsonify({"error": "Unauthorized"}), 401
    if not pipe:
        return jsonify({"error": "Model not loaded"}), 503

    data = request.get_json(force=True, silent=True) or {}
    prompt = data.get("prompt", "a beautiful landscape")
    try:
        steps = int(data.get("steps", 25))
    except (TypeError, ValueError):
        steps = 25  # non-numeric input falls back to the default instead of a 500
    steps = min(max(steps, 10), 50)  # clamp to a CPU-sane range
    request_id = data.get(
        "request_id", f"img_{int(time.time())}_{hash(prompt)%10000}"
    )

    cleanup_old_progress()
    job = {
        "request_id": request_id,
        "prompt": prompt,
        "steps": steps,
        "result": None,
        "error": None,
    }
    update_progress(request_id, 0, "Waiting in queue")
    task_queue.put(job)

    # Poll until the worker fills in result or error, with a hard deadline so
    # a dead/stuck worker cannot hang this request forever (was an infinite loop).
    deadline = time.time() + 600
    while job["result"] is None and job["error"] is None:
        if time.time() > deadline:
            return jsonify({"error": "Generation timed out",
                            "request_id": request_id}), 504
        time.sleep(0.1)

    # "is not None", not truthiness: str(exc) can be "" and the old code then
    # fell through to send_file(None) and crashed.
    if job["error"] is not None:
        return jsonify({
            "error": "Generation failed",
            "message": job["error"],
            "request_id": request_id,
        }), 500
    return send_file(
        job["result"],
        mimetype="image/png",
        download_name="image.png",
    )
@app.route("/progress/<request_id>")
# NOTE(review): route decorator was missing, so this endpoint was never
# registered — path guessed from the handler signature, confirm with callers.
def img_progress(request_id):
    """Return the latest progress snapshot for request_id as JSON.

    Unknown/expired ids get a {"progress": 0, "status": "Not found"} stub.
    """
    cleanup_old_progress()
    with progress_lock:
        return jsonify(image_progress.get(request_id, {
            "progress": 0,
            "status": "Not found",
        }))
@app.route("/health")
# NOTE(review): route decorator was missing — endpoint was unreachable;
# path guessed, confirm with the frontend/monitoring.
def health():
    """Liveness probe: reports whether the SD pipeline loaded successfully."""
    return jsonify({
        "status": "online" if pipe else "offline",
        "model_loaded": pipe is not None,
    })
if __name__ == "__main__":
    # First run: seed the whitelist file with the built-in key so the
    # key check has something to match against.
    if not os.path.exists(WL_PATH):
        with open(WL_PATH, "w") as fh:
            fh.write(UNLIMITED_KEY + "\n")
    app.run(host="0.0.0.0", port=7860, threaded=True)