guydffdsdsfd committed on
Commit
7521234
·
verified ·
1 Parent(s): 277eaa3

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +161 -34
Dockerfile CHANGED
@@ -1,48 +1,175 @@
1
  FROM python:3.10-slim
2
 
3
- # System dependencies
4
  RUN apt-get update && apt-get install -y \
5
  git \
6
  libgl1 \
7
  libglib2.0-0 \
8
- wget \
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
- # Optimized Python dependencies
12
- RUN pip install --no-cache-dir \
13
- torch==2.1.2 \
14
- torchvision==0.16.2 \
15
- torchaudio==2.1.2 \
16
- --index-url https://download.pytorch.org/whl/cu118 \
17
- && pip install --no-cache-dir \
18
- diffusers==0.26.3 \
19
- transformers==4.38.2 \
20
- accelerate==0.27.2 \
21
- safetensors==0.4.2 \
22
- flask==3.0.3 \
23
- flask-cors==4.0.0 \
24
- pillow==10.2.0 \
25
- xformers==0.0.24
26
-
27
- # Environment variables for caching
28
- ENV HOME=/home/sd
29
- ENV HF_HOME=/home/sd/.cache/huggingface
30
- ENV TRANSFORMERS_CACHE=/home/sd/.cache/huggingface/models
31
- ENV DIFFUSERS_CACHE=/home/sd/.cache/huggingface/diffusers
32
- ENV PYTHONUNBUFFERED=1
33
- ENV HF_ENDPOINT=https://hf-mirror.com
34
 
35
- # Create directory with proper permissions
 
 
 
36
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
37
 
38
- # Copy application code
39
- COPY app.py /app.py
 
 
 
 
 
40
 
41
- # Health check
42
- HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
43
- CMD python -c "import requests; requests.get('http://localhost:7860/', timeout=2)"
44
 
45
- EXPOSE 7860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Start command
48
- CMD ["python", "/app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  FROM python:3.10-slim
2
 
3
+ # ---------------- System deps ----------------
4
  RUN apt-get update && apt-get install -y \
5
  git \
6
  libgl1 \
7
  libglib2.0-0 \
 
8
  && rm -rf /var/lib/apt/lists/*
9
 
10
+ # ---------------- Python deps ----------------
11
+ # We upgrade these to ensure LCM support works
12
+ RUN pip install --no-cache-dir --upgrade \
13
+ torch \
14
+ torchvision \
15
+ torchaudio \
16
+ diffusers["torch"] \
17
+ transformers \
18
+ accelerate \
19
+ safetensors \
20
+ flask \
21
+ flask-cors \
22
+ pillow
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # ---------------- Env ----------------
25
+ ENV HOME=/home/sd
26
+ ENV HF_HOME=/home/sd/.cache
27
+ # ---------------- Storage ----------------
28
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
29
 
30
+ # ---------------- Python Application (Written directly to file) ----------------
31
+ RUN cat <<'EOF' > /app.py
32
+ from flask import Flask, request, jsonify, send_file
33
+ from flask_cors import CORS
34
+ from diffusers import DiffusionPipeline, LCMScheduler
35
+ import torch, os, json, secrets
36
+ from io import BytesIO
37
 
38
+ app = Flask(__name__)
39
+ CORS(app)
 
40
 
41
+ # -------- Paths --------
42
+ BASE = "/home/sd"
43
+ WL_PATH = f"{BASE}/whitelist.txt"
44
+ USAGE_PATH = f"{BASE}/usage.json"
45
+ LIMITS_PATH = f"{BASE}/limits.json"
46
+
47
+ DEFAULT_LIMIT = 500
48
+
49
+ # -------- Model Config --------
50
+ # LCM Dreamshaper is SD1.5 based (small) and needs only 4-8 steps (fast)
51
+ MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
52
+
53
+ # -------- Init storage --------
54
# Seed the whitelist and the two JSON stores on first start so later reads
# find a file to open. Files that already exist are left untouched.
os.makedirs(BASE, exist_ok=True)
for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
    if os.path.exists(p):
        continue
    # JSON stores start as an empty object; the whitelist starts empty.
    # The context manager guarantees the handle is flushed and closed.
    with open(p, "w") as fh:
        fh.write("{}" if p.endswith(".json") else "")
58
+
59
+ # -------- Load model once --------
60
+ print(f"Loading {MODEL_ID}...")
61
+ pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
62
+
63
+ # Ensure we use the LCM Scheduler (Fixes the IndexError crash)
64
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
65
+
66
+ # Detect hardware (CPU vs CUDA)
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+ pipe = pipe.to(device)
69
+ print(f"Model loaded on {device}")
70
+
71
+ # -------- Helpers --------
72
def whitelist():
    """Return the set of API keys stored in the whitelist file.

    The file at ``WL_PATH`` is read as whitespace-separated tokens, one
    token per key.

    Returns:
        A set of key strings. An empty set is returned when the file is
        missing, unreadable, or not decodable as text, so callers treat
        every key as unauthorized in that case.
    """
    try:
        # Context manager closes the handle even if read() fails.
        with open(WL_PATH) as fh:
            return set(fh.read().split())
    except (OSError, UnicodeDecodeError):
        return set()
77
+
78
def load_json(path):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The decoded JSON value, or an empty dict when the file cannot be
        opened or does not contain valid JSON.
    """
    try:
        # Context manager closes the handle even if decoding fails.
        with open(path) as fh:
            return json.load(fh)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
83
+
84
def save_json(path, data):
    """Write ``data`` to ``path`` as JSON, replacing any previous content.

    Args:
        path: Filesystem path of the JSON file to write.
        data: JSON-serializable value to store.

    Raises:
        TypeError: If ``data`` contains a value that is not JSON-serializable.
            The existing file is left untouched in that case.
        OSError: If the file cannot be opened or written.
    """
    # Serialize before opening: mode "w" truncates immediately, so a failed
    # dump would otherwise leave an empty or partial store behind.
    payload = json.dumps(data)
    with open(path, "w") as fh:
        fh.write(payload)
86
+
87
+ # -------- Health --------
88
@app.route("/", methods=["GET"])
def health():
    """Liveness probe: reply with a fixed plain-text banner and HTTP 200."""
    banner = "LCM Image API Running"
    status_code = 200
    return banner, status_code
91
+
92
+ # -------- Key generator --------
93
@app.route("/generate-key", methods=["POST"])
def generate_key():
    """Mint a new API key, whitelist it, and record its monthly limit.

    Optional JSON body fields:
        unlimited: Truthy to give the key no monthly cap.
        limit: Monthly cap for the key; defaults to ``DEFAULT_LIMIT``.

    Returns:
        JSON ``{"key": ..., "limit": ...}`` on success, or HTTP 400 with an
        error message when ``limit`` cannot be converted to an integer.
    """
    # NOTE(review): this endpoint performs no authentication, so any caller
    # can mint keys (including unlimited ones). Confirm that is intended.

    # silent=True yields None instead of aborting when the body is absent or
    # is not JSON, which is what the empty-dict fallback relies on.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    unlimited = data.get("unlimited", False)

    # Validate before persisting anything so a rejected request never leaves
    # a whitelisted key behind without a limits entry.
    if unlimited:
        key_limit = "unlimited"
    else:
        try:
            key_limit = int(data.get("limit", DEFAULT_LIMIT))
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "limit must be an integer"}), 400

    key = "sk-" + secrets.token_hex(16)

    with open(WL_PATH, "a") as f:
        f.write(key + "\n")

    limits = load_json(LIMITS_PATH)
    limits[key] = key_limit
    save_json(LIMITS_PATH, limits)

    return jsonify({"key": key, "limit": limits[key]})
109
+
110
+ # -------- Image generation --------
111
@app.route("/api/generate", methods=["POST"])
def generate():
    """Generate one PNG image for a whitelisted key, enforcing its monthly quota.

    Request:
        Header ``x-api-key`` carrying a whitelisted key, and a JSON body
        with a non-blank string field ``prompt``.

    Returns:
        200 with an ``image/png`` body on success.
        400 when the prompt is missing, blank, or not a string.
        401 when the key is not in the whitelist.
        429 when the key has used up its limit for the current month.
        500 with the error text when the pipeline raises.
    """
    from datetime import datetime

    key = request.headers.get("x-api-key", "")
    if key not in whitelist():
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True yields None instead of aborting when the body is absent or
    # is not JSON, so such requests get the "Prompt required" answer below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    prompt = data.get("prompt", "")
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt required"}), 400
    prompt = prompt.strip()

    # LCM Optimization:
    # LCM works best between 4 and 8 steps.
    # We ignore the user's requested 'steps' to ensure speed and stability.
    steps = 4
    guidance = 2.0

    # Rate Limiting Logic
    limit = load_json(LIMITS_PATH).get(key, DEFAULT_LIMIT)
    unlimited = (limit == "unlimited")
    month = datetime.now().strftime("%Y-%m")
    used = load_json(USAGE_PATH).get(key, {}).get(month, 0)

    if not unlimited and used >= limit:
        return jsonify({"error": "Monthly limit reached"}), 429

    # Generate
    try:
        image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guidance
        ).images[0]
    except Exception as e:
        print(f"Generation Error: {e}")
        return jsonify({"error": str(e)}), 500

    # Save Usage
    # Re-read the store here instead of reusing the pre-generation snapshot:
    # generation is slow, and writing back a stale copy would discard
    # increments recorded by requests that finished in the meantime.
    usage = load_json(USAGE_PATH)
    per_key = usage.setdefault(key, {})
    per_key[month] = per_key.get(month, 0) + 1
    save_json(USAGE_PATH, usage)

    # Return Image
    buf = BytesIO()
    image.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
161
+
162
+ if __name__ == "__main__":
163
+ app.run(host="0.0.0.0", port=7860)
164
+ EOF
165
+
166
+ # ---------------- Start ----------------
167
RUN cat <<'EOF' > /start.sh
#!/bin/bash
# Replace the shell with the Python process so the app itself receives the
# signals sent to the container's main process (e.g. on container stop).
exec python3 /app.py
EOF
171
+
172
+ RUN chmod +x /start.sh
173
+
174
+ EXPOSE 7860
175
+ ENTRYPOINT ["/bin/bash", "/start.sh"]