| from flask import Flask, render_template, request, redirect, url_for, jsonify, Request, Response | |
| import json | |
| import base64 | |
| import hashlib | |
| import json | |
| import time | |
| import os | |
| import queue | |
| import zstandard as zstd | |
# Shared in-memory job queue. enqueue() puts JSON-encoded job dicts on it;
# dequeue() and worker() take them off again.
q: queue.Queue = queue.Queue()

# Every submitted job, in submission order. Each entry looks like this
# (loose notation, not strict JSON):
#   {
#       "status": "queue" | "progress" | "done",
#       "prompt": "Some ducks...",
#       "id": "abc123",
#   }
models: list = []

# Worker records scanned by worker(): it reads "status" and writes "prompt"
# and "id". NOTE(review): no visible code populates this list.
WORKERS: list = []
def enqueue(prompt: str):
    """Register a new job for *prompt* and put it on the shared queue.

    The job id is ``"<sha256 hex of prompt>.<time.time()>"``. The timestamp
    suffix distinguishes resubmissions of the same prompt, unless two calls
    observe an identical ``time.time()`` value.

    The job record is appended to ``models`` with status ``"queue"``, and a
    JSON-encoded copy is pushed onto ``q``.

    Args:
        prompt: Prompt text, taken from the ``/enqueue/...`` route.

    Returns:
        A Flask JSON response ``{"status": "ok", "id": <job id>}``. The id is
        included so the submitter can correlate its job later.
    """
    digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    job_id = f"{digest}.{time.time()}"
    md = {
        "status": "queue",
        "prompt": prompt,
        "id": job_id,
    }
    models.append(md)
    # The queue carries a serialized copy; `models` keeps the live record.
    q.put(json.dumps(md))
    return jsonify({"status": "ok", "id": job_id})
def dequeue():
    """Hand the oldest queued job to a polling client.

    Returns:
        A Flask JSON response ``{"status": "ok", "prompt": ..., "id": ...}``
        when a job was taken off ``q``, otherwise ``{"status": "empty"}``.
    """
    # EAFP: checking q.empty() and then calling q.get_nowait() can race with
    # another consumer (concurrent requests, or worker()). The second call
    # would then raise queue.Empty and turn an empty queue into a 500.
    try:
        raw = q.get_nowait()
    except queue.Empty:
        return jsonify({"status": "empty"})
    job = json.loads(raw)
    return jsonify({
        "status": "ok",
        "prompt": job["prompt"],
        "id": job["id"],
    })
def _resolve_output_path(rel_path, base_dir="files"):
    """Return the absolute path of client-supplied *rel_path* under *base_dir*.

    Raises:
        ValueError: if the resolved path is *base_dir* itself or escapes it
            (e.g. via ``..`` components or symlinks).
    """
    base = os.path.realpath(base_dir)
    # Strip leading "/" so "/a/b" maps to "<base>/a/b", matching the old
    # f"files/{path}" concatenation instead of os.path.join's absolute reset.
    target = os.path.realpath(os.path.join(base, str(rel_path).lstrip("/")))
    if target == base or os.path.commonpath([base, target]) != base:
        raise ValueError(f"unsafe output path: {rel_path!r}")
    return target


def complete(data=None):
    """Store the files a worker produced for a job and mark the job done.

    Expected payload (JSON)::

        {"_id": "<job id>",
         "files": [{"path": "<relative path>",
                    "data": "<base64 of zstd-compressed bytes>"}]}

    Args:
        data: Raw JSON payload (str or bytes). When omitted, the current
            request body is used. Flask dispatches the ``/complete`` route
            without arguments because the route has no URL variables.

    Returns:
        ``{"status": "ok"}`` on success. Returns
        ``({"status": "error", "error": ...}, 400)`` if any file path would
        land outside ``files/``. Unknown ids are ignored and still answered
        with ``{"status": "ok"}``.
    """
    if data is None:
        data = request.get_data()
    jsn = json.loads(data)
    for md in models:
        if md["id"] != jsn["_id"]:
            continue
        # The paths come from the client: validate all of them before
        # writing anything, so a rejected upload leaves no partial output.
        try:
            outputs = [
                (_resolve_output_path(fl["path"]), fl["data"])
                for fl in jsn.get("files", [])
            ]
        except ValueError as err:
            return jsonify({"status": "error", "error": str(err)}), 400
        for target, payload in outputs:
            rd = zstd.decompress(base64.b64decode(payload))
            # Create only the parent directories. Calling makedirs on the
            # target itself fails when the target already exists as a file.
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with open(target, "wb") as f:
                f.write(rd)
        # Mark done only after every file has been written.
        md["status"] = "done"
        break
    return jsonify({"status": "ok"})
def worker(poll_interval: float = 1.0):
    """Forever move queued jobs from ``q`` onto idle entries of ``WORKERS``.

    Each job taken off the queue is assigned to the first worker record whose
    ``"status"`` is ``"idle"``; that record's ``"prompt"`` and ``"id"`` are
    overwritten. The matching ``models`` entry is then set to
    ``"progress"``. If no worker is idle, the job is put back on the queue
    and the loop backs off for *poll_interval* seconds. Never returns.

    Args:
        poll_interval: Seconds to block waiting for a job, and to back off
            when no worker is idle.

    NOTE(review): the assigned worker's ``"status"`` stays ``"idle"``, so a
    later job may overwrite the same worker's prompt/id. No visible code
    resets a worker to idle, so the intended busy/idle protocol is unclear;
    confirm before changing it. This loop also competes with the
    ``/dequeue`` endpoint for the same queue, and no visible code starts it.
    """
    while True:
        # Block with a timeout instead of empty() + get_nowait(). This avoids
        # a race with other consumers (queue.Empty would kill this loop) and
        # avoids spinning on an empty queue.
        try:
            raw = q.get(timeout=poll_interval)
        except queue.Empty:
            continue
        job = json.loads(raw)
        idle = next((w for w in WORKERS if w["status"] == "idle"), None)
        if idle is None:
            # Re-queue the original JSON string. jsonify() would produce a
            # Flask Response, which json.loads() cannot parse on the next
            # pass and which requires an active Flask application context.
            q.put(raw)
            time.sleep(poll_interval)
            continue
        idle["prompt"] = job["prompt"]
        idle["id"] = job["id"]
        # `job` is a decoded copy; update the shared record in `models`.
        for md in models:
            if md["id"] == job["id"]:
                md["status"] = "progress"
                break
# static_url_path is pinned to "/static". Passing only static_folder="public"
# would move the static route to "/public"; the original registered it as
# "/static" and changed the folder afterwards.
app = Flask(__name__, static_folder="public", static_url_path="/static")


def index():
    """Serve the minimal landing page."""
    return """
<html>
<head>
<title>Snail</title>
</head>
<body>
<h1>Snail</h1>
</body>
</html>
"""


# Register routes at import time so the app also works under a WSGI server
# or `flask run`, not only when this file is executed directly.
# The `path:` converter lets prompts contain "/" characters.
app.add_url_rule("/enqueue/<path:prompt>", "enqueue", enqueue, methods=["POST"])
app.add_url_rule("/dequeue", "dequeue", dequeue, methods=["GET"])
app.add_url_rule("/complete", "complete", complete, methods=["POST"])
app.add_url_rule("/", "index", index, methods=["GET"])

if __name__ == "__main__":
    # Development server, listening on all interfaces at port 7860.
    app.run(port=7860, host="0.0.0.0")