File size: 4,215 Bytes
b82afe8
 
c697687
 
7521234
c697687
b82afe8
 
 
c697687
b82afe8
 
85a948d
c697687
 
85a948d
 
 
 
c697687
 
 
85a948d
 
 
7521234
 
 
 
622a459
7521234
 
 
c697687
 
85a948d
c697687
7521234
b82afe8
622a459
c697687
7521234
 
 
 
 
 
85a948d
622a459
7521234
 
631c6d3
7521234
 
 
 
 
 
 
 
 
 
 
 
 
85a948d
c697687
 
 
 
 
85a948d
c697687
 
 
7521234
85a948d
7521234
85a948d
c697687
 
 
85a948d
7521234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a948d
7521234
 
c697687
7521234
1d9f31c
7521234
85a948d
7521234
 
 
 
 
 
 
 
 
 
c697687
 
7521234
 
 
c697687
7521234
 
85a948d
c697687
1d9f31c
7521234
 
 
 
 
1d9f31c
 
7521234
c697687
1d9f31c
 
 
 
 
 
 
c697687
7521234
 
 
 
 
 
 
 
c697687
7521234
 
c697687
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# syntax=docker/dockerfile:1
FROM python:3.10-slim

ENV DEBIAN_FRONTEND=noninteractive

# ---------------- System deps ----------------
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    libgl1 \
    libglib2.0-0 \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# ---------------- Python deps (PINNED & COMPATIBLE) ----------------
# Version specifiers containing "<" MUST be quoted: unquoted, the shell parses
# `numpy<2` as `numpy` with stdin redirected from a file named "2", and the
# RUN step fails.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir \
    "numpy<2" \
    torch==2.0.1 \
    torchvision==0.15.2 \
    torchaudio==2.0.2 \
    --index-url https://download.pytorch.org/whl/cpu

# huggingface_hub is kept below 0.26 because that release removed
# `cached_download`, which diffusers 0.24.0 imports when the package loads.
# "numpy<2" is repeated so this step can never move NumPy to 2.x underneath
# the torch 2.0.1 build installed above.
RUN pip install --no-cache-dir \
    "numpy<2" \
    diffusers==0.24.0 \
    transformers==4.36.2 \
    accelerate==0.25.0 \
    "huggingface_hub<0.26" \
    safetensors \
    flask \
    flask-cors \
    pillow

# ---------------- Env ----------------
ENV HOME=/home/sd
ENV HF_HOME=/home/sd/.cache
ENV OMP_NUM_THREADS=1
ENV MKL_NUM_THREADS=1
ENV NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=0

# ---------------- Storage ----------------
# chmod 777: whatever UID the container runs as can write the app's state
# files and the HF cache under /home/sd.
RUN mkdir -p /home/sd && chmod -R 777 /home/sd

# ---------------- App ----------------
# The heredoc below requires Dockerfile syntax >= 1.4 (hence the `# syntax`
# directive, which only takes effect as the very first line of the file).
RUN cat <<'EOF' > /app.py
import json
import os
import secrets
import tempfile
import threading
from datetime import datetime
from io import BytesIO

import torch
from diffusers import DiffusionPipeline, LCMScheduler
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS

app = Flask(__name__)
CORS(app)

BASE = "/home/sd"
WL_PATH = f"{BASE}/whitelist.txt"
USAGE_PATH = f"{BASE}/usage.json"
LIMITS_PATH = f"{BASE}/limits.json"

DEFAULT_LIMIT = 500
MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"

# Create the state directory and seed empty state files on first start:
# "{}" for the JSON stores, an empty file for the newline-separated whitelist.
# `with` guarantees each seed file is flushed and closed immediately instead of
# whenever the garbage collector gets to the orphaned file object.
os.makedirs(BASE, exist_ok=True)
for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
    if not os.path.exists(p):
        with open(p, "w", encoding="utf-8") as fh:
            fh.write("{}" if p.endswith(".json") else "")

print(f"Loading model: {MODEL_ID}")

# Inference only, so autograd is switched off. NOTE(review): PyTorch grad mode
# is thread-local, so this affects only the thread running module setup; the
# request-handling threads presumably rely on the pipeline's own no-grad
# handling — confirm.
torch.set_grad_enabled(False)


def _build_pipeline():
    """Load MODEL_ID in float32 on CPU with an LCM scheduler and slicing on."""
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        safety_checker=None,
    )
    # Replace the scheduler with an LCMScheduler built from the checkpoint's
    # own scheduler config.
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cpu")

    # ---- MEMORY + SPEED ----
    pipeline.enable_attention_slicing()
    pipeline.enable_vae_slicing()
    return pipeline


pipe = _build_pipeline()

print("Model loaded (CPU, optimized)")

def whitelist(path=None):
    """Return the set of API keys allowed to call ``/api/generate``.

    Keys are the whitespace-separated tokens of the whitelist file, so blank
    lines and duplicates are harmless.

    Args:
        path: Whitelist file to read; ``None`` (the default) means ``WL_PATH``.

    Returns:
        A set of key strings. It is empty if the file is missing, unreadable,
        or not valid UTF-8, in which case every request is rejected.
    """
    if path is None:
        path = WL_PATH
    try:
        with open(path, encoding="utf-8") as fh:
            return set(fh.read().split())
    except (OSError, UnicodeDecodeError):
        return set()

def load_json(path):
    """Load a JSON object from ``path``, falling back to ``{}``.

    Every caller treats the result as a mapping (``.get`` / item assignment),
    so a missing or unreadable file, invalid JSON, and JSON whose top level is
    not an object all yield an empty dict rather than an exception.

    Args:
        path: File to read.

    Returns:
        The parsed dict, or ``{}``.
    """
    try:
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}

def save_json(path, data):
    """Atomically write ``data`` as JSON to ``path``.

    The JSON goes to a temporary file in the same directory, which is then
    moved over ``path`` with ``os.replace``. A crash mid-write therefore
    leaves the previous file intact instead of a truncated one. This matters
    because ``load_json`` turns unparsable JSON into ``{}``, which would
    silently reset usage counters or drop per-key limits.

    Args:
        path: Destination file.
        data: JSON-serializable object.

    Raises:
        OSError / TypeError: propagated after the temporary file is removed.
    """
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            json.dump(data, fh)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave stray temp files behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise

@app.route("/", methods=["GET"])
def health():
    """Liveness check: reply 200 with a fixed plain-text banner."""
    banner = "LCM Image API Running"
    status = 200
    return banner, status

@app.route("/generate-key", methods=["POST"])
def generate_key():
    """Create a new API key, record its monthly limit, then whitelist it.

    Optional JSON body:
        ``unlimited`` (truthy): the key gets no monthly cap.
        ``limit`` (int >= 0): monthly cap; defaults to ``DEFAULT_LIMIT``.

    If the ``ADMIN_KEY`` environment variable is set, callers must send the
    same value in the ``x-admin-key`` header. When it is unset the endpoint
    stays open, i.e. anyone who can reach the server can mint keys.

    Returns:
        200 ``{"key": ..., "limit": ...}``; 400 for a non-object body or an
        invalid ``limit``; 401 when ``ADMIN_KEY`` is set and does not match.
    """
    admin_key = os.environ.get("ADMIN_KEY", "")
    if admin_key:
        supplied = request.headers.get("x-admin-key", "")
        # Constant-time comparison so the admin secret can't be probed by timing.
        if not secrets.compare_digest(supplied.encode(), admin_key.encode()):
            return jsonify({"error": "Unauthorized"}), 401

    # A bare POST (no JSON content type) means "use the defaults"; calling
    # get_json() unconditionally raises 415 there on Werkzeug >= 2.3.
    # Malformed JSON sent as application/json still gets Flask's 400.
    data = (request.get_json() if request.is_json else None) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    # Validate before touching any state so a bad request leaves nothing behind.
    if data.get("unlimited"):
        limit = "unlimited"
    else:
        try:
            limit = int(data.get("limit", DEFAULT_LIMIT))
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "limit must be an integer"}), 400
        if limit < 0:
            return jsonify({"error": "limit must be >= 0"}), 400

    key = "sk-" + secrets.token_hex(16)

    # Store the limit first: a whitelisted key is immediately usable, and with
    # no limits entry /api/generate would fall back to DEFAULT_LIMIT.
    limits = load_json(LIMITS_PATH)
    limits[key] = limit
    save_json(LIMITS_PATH, limits)

    with open(WL_PATH, "a", encoding="utf-8") as f:
        f.write(key + "\n")

    return jsonify({"key": key, "limit": limit})

# One lock guards the whole quota-check -> generate -> record-usage sequence.
# app.run() serves requests on several threads (Flask's default), but there is
# a single shared `pipe` whose scheduler keeps per-run state. Also, an
# unsynchronized read-modify-write of usage.json lets concurrent requests
# overshoot a limit or lose increments. CPU generation is effectively serial
# anyway, so this costs little.
_GENERATE_LOCK = threading.Lock()


@app.route("/api/generate", methods=["POST"])
def generate():
    """Generate one PNG image from a text prompt for a whitelisted API key.

    Request:
        Header ``x-api-key``: must be present in the whitelist file.
        JSON body: ``{"prompt": "<non-empty string>"}``.

    Returns:
        200 ``image/png`` on success; 401 for an unknown key; 400 when the
        prompt is missing, blank, or not a string; 429 once the key's count
        for the current ``%Y-%m`` month (local time) reaches its limit; 500
        with the exception text if generation or bookkeeping fails.
    """
    key = request.headers.get("x-api-key", "")
    if key not in whitelist():
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True: Werkzeug >= 2.3 raises 415 for a non-JSON body; treat that
    # like an empty body so the client gets the clearer "Prompt required".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    prompt = data.get("prompt", "")
    prompt = prompt.strip() if isinstance(prompt, str) else ""
    if not prompt:
        return jsonify({"error": "Prompt required"}), 400

    with _GENERATE_LOCK:
        limits = load_json(LIMITS_PATH)
        usage = load_json(USAGE_PATH)

        month = datetime.now().strftime("%Y-%m")
        used = usage.get(key, {}).get(month, 0)
        limit = limits.get(key, DEFAULT_LIMIT)

        if limit != "unlimited" and used >= limit:
            return jsonify({"error": "Monthly limit reached"}), 429

        try:
            image = pipe(
                prompt=prompt,
                num_inference_steps=4,
                guidance_scale=1.5,
            ).images[0]

            buf = BytesIO()
            image.save(buf, format="PNG")
            buf.seek(0)

            # Count the request only once an image has been produced and encoded.
            usage.setdefault(key, {})[month] = used + 1
            save_json(USAGE_PATH, usage)
        except Exception as e:
            # Keep a server-side traceback; the client still gets the message.
            app.logger.exception("Image generation failed")
            return jsonify({"error": str(e)}), 500

    return send_file(buf, mimetype="image/png")

if __name__ == "__main__":
    # Listen on every interface so the port exposed by the container
    # (EXPOSE 7860) is reachable from outside it.
    app.run(port=7860, host="0.0.0.0")
EOF

# ---------------- Start ----------------
# printf rather than echo: whether echo expands "\n" depends on the shell that
# runs RUN (dash does, bash does not). Without the newline, start.sh would be
# one comment line that launches nothing.
# `exec` replaces bash with Python so signals reach the server process, and
# -u disables stdout buffering so print() output shows up in the logs at once.
RUN printf '#!/bin/bash\nexec python3 -u /app.py\n' > /start.sh && chmod +x /start.sh

# Python handles SIGINT out of the box (KeyboardInterrupt ends app.run()), but
# it installs no SIGTERM handler, and as PID 1 an unhandled SIGTERM is ignored.
STOPSIGNAL SIGINT

EXPOSE 7860
ENTRYPOINT ["/bin/bash", "/start.sh"]