# ollama-picture/app.py — Hugging Face Space source
# (last commit 01c0349 "Update app.py" by guydffdsdsfd)
import io
import os
import threading
import time
import uuid
from queue import Queue
from urllib.parse import urlsplit

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
app = Flask(__name__)

# CORS policy applied to every route: any origin may call the API with these
# methods/headers. txt2img performs its own origin / API-key check.
_cors_policy = {
    "origins": ["*"],
    "methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["Content-Type", "Authorization", "x-api-key"],
}
CORS(app, resources={r"/*": _cors_policy})
# Deployment settings. Each can be overridden through the environment; the
# defaults are the values this service has always used.
WL_PATH = os.environ.get("WL_PATH", "whitelist.txt")
# NOTE(security): this default key is committed to source and therefore public.
# Set UNLIMITED_KEY in the environment (and rotate the whitelist file) for any
# deployment that needs real access control.
UNLIMITED_KEY = os.environ.get("UNLIMITED_KEY", "sk-ess4l0ri37")
# Comma-separated hostnames whose Origin/Referer bypass the API-key check.
_trusted_domains_env = os.environ.get("TRUSTED_DOMAINS", "kaigpt.vercel.app,localhost")
TRUSTED_DOMAINS = [d.strip() for d in _trusted_domains_env.split(",") if d.strip()]

# request_id -> {"progress", "status", "timestamp"}; always accessed under progress_lock.
image_progress = {}
progress_lock = threading.Lock()
# Generation jobs, consumed by the background worker thread.
task_queue = Queue()
# Hub repositories to try, in order. SD_MODEL_ID (if set) goes first. The
# original "runwayml/stable-diffusion-v1-5" repo was taken down from the
# Hugging Face Hub; it is still tried (a local cache may hold it) before
# falling back to the official mirror of the same weights.
_MODEL_CANDIDATES = [
    m
    for m in (
        os.environ.get("SD_MODEL_ID"),
        "runwayml/stable-diffusion-v1-5",
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
    )
    if m
]

print("Loading Stable Diffusion...")
# Stays None if every candidate fails; the HTTP handlers then report the
# model as unavailable.
pipe = None
for _model_id in _MODEL_CANDIDATES:
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            _model_id,
            torch_dtype=torch.float32,  # full precision; the pipeline is moved to CPU below
            safety_checker=None,
            requires_safety_checker=False,
        ).to("cpu")
    except Exception as e:  # top-level boundary: log and try the next candidate
        print("Model failed:", _model_id, e)
    else:
        print("Model loaded")
        break
def get_whitelist(path=None):
    """Return the set of API keys accepted for untrusted callers.

    Keys are read one per line from *path* (default: ``WL_PATH``). Surrounding
    whitespace is stripped and blank lines are ignored. If the file does not
    exist, only ``UNLIMITED_KEY`` is accepted.

    The file is re-read on every call, so edits take effect without a restart.
    """
    if path is None:
        path = WL_PATH
    # EAFP: open directly instead of exists()+open(), so a file removed in
    # between cannot raise FileNotFoundError.
    try:
        with open(path, encoding="utf-8") as f:
            return {line.strip() for line in f if line.strip()}
    except FileNotFoundError:
        return {UNLIMITED_KEY}
def is_trusted_origin(headers=None, domains=None):
    """Return True if the Origin or Referer header points at a trusted host.

    A header matches when the hostname of its URL equals a trusted domain or
    is a subdomain of one (``app.localhost`` matches ``localhost``). The
    comparison is case-insensitive. A plain substring test is deliberately
    not used: it also accepted look-alike hosts such as
    ``kaigpt.vercel.app.evil.com`` or ``notlocalhost.example``.

    NOTE(security): both headers are supplied by the client. This only
    recognizes well-behaved browsers; any non-browser client can forge them.

    Args:
        headers: mapping with a ``get`` method; defaults to the current Flask
            ``request.headers``.
        domains: iterable of trusted hostnames; defaults to ``TRUSTED_DOMAINS``.
    """
    if headers is None:
        headers = request.headers
    if domains is None:
        domains = TRUSTED_DOMAINS
    trusted = {d.strip().lower() for d in domains if d.strip()}

    for name in ("Origin", "Referer"):
        value = headers.get(name, "") or ""
        try:
            # .hostname is already lower-cased and has any port stripped.
            host = urlsplit(value).hostname
        except ValueError:  # malformed URL, e.g. an unbalanced IPv6 bracket
            continue
        if host and any(host == d or host.endswith("." + d) for d in trusted):
            return True
    return False
def update_progress(request_id, progress, status):
    """Store the latest progress snapshot for *request_id*.

    Replaces any previous entry with a new dict holding the percentage, a
    human-readable status and the wall-clock time of the update. The write
    happens under ``progress_lock``.
    """
    with progress_lock:
        image_progress[request_id] = dict(
            progress=progress,
            status=status,
            timestamp=time.time(),
        )
def cleanup_old_progress(max_age=300):
    """Remove progress entries not updated within the last *max_age* seconds.

    Args:
        max_age: maximum age in seconds (default 300, the previous hard-coded
            value). Entries strictly older than this are deleted.
    """
    with progress_lock:
        now = time.time()
        # Collect first, then delete: the dict cannot be resized while iterating it.
        stale = [
            key
            for key, entry in image_progress.items()
            if now - entry["timestamp"] > max_age
        ]
        for key in stale:
            del image_progress[key]
def _run_job(job):
    """Run one generation job and store its outcome on the job dict.

    On success, sets ``job["result"]`` to an ``io.BytesIO`` PNG rewound to
    position 0. On failure, sets ``job["error"]`` to a non-empty message. The
    request handler polls for whichever appears.
    """
    request_id = job["request_id"]
    prompt = job["prompt"]
    steps = job["steps"]
    update_progress(request_id, 0, "Queued")
    try:
        def on_step(step, timestep, latents):
            # diffusers passes a 0-based step index; report completed steps so
            # the final callback reads steps/steps instead of (steps-1)/steps.
            done = step + 1
            update_progress(request_id, int(done / steps * 100), f"Step {done}/{steps}")

        with torch.no_grad():
            img = pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                callback=on_step,
                callback_steps=1,
            ).images[0]
        bio = io.BytesIO()
        img.save(bio, "PNG")
        bio.seek(0)
        update_progress(request_id, 100, "Complete")
        # Publish last: the waiting handler returns as soon as it sees a result.
        job["result"] = bio
    except Exception as e:
        # str(e) is "" for e.g. a bare ValueError(); fall back to the type name
        # so the handler always sees a non-empty error.
        job["error"] = str(e) or type(e).__name__
        update_progress(request_id, 0, "Error")


def worker():
    """Process jobs from ``task_queue`` one at a time until a ``None`` sentinel arrives."""
    while True:
        job = task_queue.get()
        try:
            if job is None:
                break
            _run_job(job)
        except Exception as e:
            # Errors outside _run_job's own handler (e.g. a malformed job) must
            # not kill the only worker thread: txt2img polls without a timeout,
            # so every later request would hang.
            print("Worker error:", e)
            if isinstance(job, dict) and job.get("result") is None and job.get("error") is None:
                job["error"] = str(e) or type(e).__name__
        finally:
            # Also acknowledges the sentinel, so Queue.join() cannot block on it.
            task_queue.task_done()
# Start one background consumer for task_queue. As a daemon thread it does
# not keep the process alive on shutdown.
_worker_thread = threading.Thread(target=worker, daemon=True)
_worker_thread.start()
@app.route("/api/txt2img", methods=["POST", "OPTIONS"])
def txt2img():
    """Generate a PNG from a text prompt and return it as the response body.

    The request thread blocks until the background worker has processed the
    job. Optional JSON body fields:

    * ``prompt`` -- text prompt (default ``"a beautiful landscape"``).
    * ``steps`` -- inference steps, clamped to 10..50 (default 25).
    * ``request_id`` -- key for polling ``/api/img_progress/<request_id>``;
      generated when absent.
    * ``api_key`` -- alternative to the ``x-api-key`` header. It is needed
      only when the caller is not a trusted origin.

    Returns the PNG on success. Otherwise returns a JSON error with status
    401 (bad key), 503 (model not loaded), 400 (bad body) or 500 (generation
    failed).
    """
    if request.method == "OPTIONS":
        return jsonify({"status": "ok"})

    # Parse once, without raising, so the auth check cannot crash on a missing
    # or non-JSON body. The body is validated after auth, so unauthorized
    # callers still get 401.
    data = request.get_json(force=True, silent=True)

    if not is_trusted_origin():
        body_key = data.get("api_key", "") if isinstance(data, dict) else ""
        api_key = request.headers.get("x-api-key") or body_key
        # A non-string key (e.g. a JSON list) is unhashable; reject it instead
        # of letting the set lookup raise TypeError.
        if not isinstance(api_key, str) or api_key not in get_whitelist():
            return jsonify({"error": "Unauthorized"}), 401

    if not pipe:
        return jsonify({"error": "Model not loaded"}), 503

    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get("prompt", "a beautiful landscape")
    if not isinstance(prompt, str):
        return jsonify({"error": "prompt must be a string"}), 400

    try:
        steps = int(data.get("steps", 25))
    except (TypeError, ValueError, OverflowError):  # None, "abc", NaN, Infinity
        return jsonify({"error": "steps must be an integer"}), 400
    steps = min(max(steps, 10), 50)

    # The id is stringified so it matches the string the progress route
    # receives from its URL. A random suffix replaces hash(prompt), which
    # collided for identical prompts in the same second.
    request_id = str(
        data.get("request_id") or f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}"
    )

    cleanup_old_progress()
    job = {
        "request_id": request_id,
        "prompt": prompt,
        "steps": steps,
        "result": None,
        "error": None,
    }
    update_progress(request_id, 0, "Waiting in queue")
    task_queue.put(job)

    # The worker fills in exactly one of result/error. Compare with None
    # explicitly: an empty error string must still be reported as a failure,
    # not fall through to send_file(None).
    while job["result"] is None and job["error"] is None:
        time.sleep(0.1)

    if job["error"] is not None:
        return jsonify({
            "error": "Generation failed",
            "message": job["error"],
            "request_id": request_id,
        }), 500

    return send_file(
        job["result"],
        mimetype="image/png",
        download_name="image.png",
    )
@app.route("/api/img_progress/<request_id>")
def img_progress(request_id):
    """Return the last recorded progress for *request_id* as JSON.

    Stale entries are pruned first. Unknown ids get a zero-progress
    "Not found" payload rather than a 404.
    """
    cleanup_old_progress()
    not_found = {"progress": 0, "status": "Not found"}
    with progress_lock:
        snapshot = image_progress.get(request_id, not_found)
        return jsonify(snapshot)
@app.route("/api/health")
def health():
    """Health endpoint: reports whether the diffusion pipeline loaded."""
    status = "online" if pipe else "offline"
    payload = {"status": status, "model_loaded": pipe is not None}
    return jsonify(payload)
if __name__ == "__main__":
    # Seed the whitelist on first run so the built-in key is accepted.
    # utf-8 matches the encoding get_whitelist reads with.
    if not os.path.exists(WL_PATH):
        with open(WL_PATH, "w", encoding="utf-8") as f:
            f.write(UNLIMITED_KEY + "\n")
    # threaded=True: each txt2img request blocks its thread while waiting on
    # the worker, so progress polls need their own threads. PORT may override
    # the default 7860.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")), threaded=True)