guydffdsdsfd committed on
Commit
c697687
·
verified ·
1 Parent(s): 1d9f31c

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +48 -39
Dockerfile CHANGED
@@ -1,18 +1,27 @@
1
  FROM python:3.10-slim
2
 
 
 
3
  # ---------------- System deps ----------------
4
- RUN apt-get update && apt-get install -y \
5
  git \
6
  libgl1 \
7
  libglib2.0-0 \
 
8
  && rm -rf /var/lib/apt/lists/*
9
 
10
  # ---------------- Python deps ----------------
11
- RUN pip install --no-cache-dir --upgrade \
12
- torch \
13
- torchvision \
14
- torchaudio \
15
- diffusers["torch"] \
 
 
 
 
 
 
16
  transformers \
17
  accelerate \
18
  safetensors \
@@ -23,12 +32,14 @@ RUN pip install --no-cache-dir --upgrade \
23
  # ---------------- Env ----------------
24
  ENV HOME=/home/sd
25
  ENV HF_HOME=/home/sd/.cache
26
- # Limit threads to prevent CPU choking
27
- ENV OMP_NUM_THREADS=1
 
 
28
  # ---------------- Storage ----------------
29
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
30
 
31
- # ---------------- Python Application ----------------
32
  RUN cat <<'EOF' > /app.py
33
  from flask import Flask, request, jsonify, send_file
34
  from flask_cors import CORS
@@ -39,7 +50,6 @@ from io import BytesIO
39
  app = Flask(__name__)
40
  CORS(app)
41
 
42
- # -------- Paths --------
43
  BASE = "/home/sd"
44
  WL_PATH = f"{BASE}/whitelist.txt"
45
  USAGE_PATH = f"{BASE}/usage.json"
@@ -48,31 +58,38 @@ LIMITS_PATH = f"{BASE}/limits.json"
48
  DEFAULT_LIMIT = 500
49
  MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
50
 
51
- # -------- Init storage --------
52
  os.makedirs(BASE, exist_ok=True)
53
  for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
54
  if not os.path.exists(p):
55
  open(p, "w").write("{}" if p.endswith(".json") else "")
56
 
57
- # -------- Load model --------
58
  print(f"Loading {MODEL_ID}...")
59
- pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
 
 
60
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
61
 
62
  device = "cuda" if torch.cuda.is_available() else "cpu"
63
  pipe = pipe.to(device)
64
 
65
- # -------- CRITICAL MEMORY FIXES --------
66
- if device == "cpu":
67
- # Slices attention computation into chunks (Saves ~2GB RAM)
68
- pipe.enable_attention_slicing()
69
- # Slices VAE decoding (Saves ~1GB RAM)
70
- pipe.enable_vae_slicing()
71
- print("Memory optimizations enabled for CPU.")
72
 
73
- print(f"Model loaded on {device}")
74
 
75
- # -------- Helpers --------
76
  def whitelist():
77
  try:
78
  return set(open(WL_PATH).read().split())
@@ -88,7 +105,6 @@ def load_json(path):
88
  def save_json(path, data):
89
  json.dump(data, open(path, "w"))
90
 
91
- # -------- Routes --------
92
  @app.route("/", methods=["GET"])
93
  def health():
94
  return "LCM Image API Running", 200
@@ -99,7 +115,7 @@ def generate_key():
99
  key = "sk-" + secrets.token_hex(16)
100
  with open(WL_PATH, "a") as f:
101
  f.write(key + "\n")
102
-
103
  limits = load_json(LIMITS_PATH)
104
  limits[key] = "unlimited" if data.get("unlimited") else int(data.get("limit", DEFAULT_LIMIT))
105
  save_json(LIMITS_PATH, limits)
@@ -113,29 +129,27 @@ def generate():
113
 
114
  data = request.get_json() or {}
115
  prompt = data.get("prompt", "").strip()
116
- if not prompt: return jsonify({"error": "Prompt required"}), 400
 
117
 
118
- # Rate Limiting
119
  limits = load_json(LIMITS_PATH)
120
  usage = load_json(USAGE_PATH)
121
  limit = limits.get(key, DEFAULT_LIMIT)
122
-
123
  from datetime import datetime
124
  month = datetime.now().strftime("%Y-%m")
125
  used = usage.get(key, {}).get(month, 0)
126
-
127
  if limit != "unlimited" and used >= limit:
128
  return jsonify({"error": "Monthly limit reached"}), 429
129
 
130
- # Generate
131
  try:
132
- # Hardcoded for stability
133
  image = pipe(
134
  prompt=prompt,
135
  num_inference_steps=4,
136
  guidance_scale=1.5
137
  ).images[0]
138
-
139
  usage.setdefault(key, {})[month] = used + 1
140
  save_json(USAGE_PATH, usage)
141
 
@@ -143,8 +157,8 @@ def generate():
143
  image.save(buf, format="PNG")
144
  buf.seek(0)
145
  return send_file(buf, mimetype="image/png")
 
146
  except Exception as e:
147
- print(f"Error: {e}")
148
  return jsonify({"error": str(e)}), 500
149
 
150
  if __name__ == "__main__":
@@ -152,12 +166,7 @@ if __name__ == "__main__":
152
  EOF
153
 
154
  # ---------------- Start ----------------
155
- RUN cat <<'EOF' > /start.sh
156
- #!/bin/bash
157
- python3 /app.py
158
- EOF
159
-
160
- RUN chmod +x /start.sh
161
 
162
  EXPOSE 7860
163
- ENTRYPOINT ["/bin/bash", "/start.sh"]
 
1
  FROM python:3.10-slim
2
 
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+
5
  # ---------------- System deps ----------------
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
  git \
8
  libgl1 \
9
  libglib2.0-0 \
10
+ ca-certificates \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
  # ---------------- Python deps ----------------
14
+ # Install torch FIRST, pinned, CPU by default (much smaller + stable)
15
+ RUN pip install --no-cache-dir --upgrade pip && \
16
+ pip install --no-cache-dir \
17
+ torch==2.1.2 \
18
+ torchvision==0.16.2 \
19
+ torchaudio==2.1.2 \
20
+ --index-url https://download.pytorch.org/whl/cpu
21
+
22
+ # Then the rest
23
+ RUN pip install --no-cache-dir \
24
+ diffusers[torch] \
25
  transformers \
26
  accelerate \
27
  safetensors \
 
32
  # ---------------- Env ----------------
33
  ENV HOME=/home/sd
34
  ENV HF_HOME=/home/sd/.cache
35
+ ENV OMP_NUM_THREADS=1
36
+ ENV MKL_NUM_THREADS=1
37
+ ENV PYTORCH_ENABLE_MPS_FALLBACK=1
38
+
39
  # ---------------- Storage ----------------
40
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
41
 
42
+ # ---------------- App ----------------
43
  RUN cat <<'EOF' > /app.py
44
  from flask import Flask, request, jsonify, send_file
45
  from flask_cors import CORS
 
50
  app = Flask(__name__)
51
  CORS(app)
52
 
 
53
  BASE = "/home/sd"
54
  WL_PATH = f"{BASE}/whitelist.txt"
55
  USAGE_PATH = f"{BASE}/usage.json"
 
58
  DEFAULT_LIMIT = 500
59
  MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
60
 
 
61
  os.makedirs(BASE, exist_ok=True)
62
  for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
63
  if not os.path.exists(p):
64
  open(p, "w").write("{}" if p.endswith(".json") else "")
65
 
 
66
  print(f"Loading {MODEL_ID}...")
67
+
68
+ torch.set_grad_enabled(False)
69
+ torch.backends.cuda.matmul.allow_tf32 = True
70
+ torch.backends.cudnn.allow_tf32 = True
71
+
72
+ pipe = DiffusionPipeline.from_pretrained(
73
+ MODEL_ID,
74
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
75
+ safety_checker=None
76
+ )
77
+
78
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
79
 
80
  device = "cuda" if torch.cuda.is_available() else "cpu"
81
  pipe = pipe.to(device)
82
 
83
+ # ---- SPEED + MEMORY OPTS ----
84
+ pipe.enable_attention_slicing()
85
+ pipe.enable_vae_slicing()
86
+
87
+ if device == "cuda":
88
+ pipe.enable_model_cpu_offload()
89
+ pipe.unet.to(memory_format=torch.channels_last)
90
 
91
+ print(f"Model ready on {device}")
92
 
 
93
  def whitelist():
94
  try:
95
  return set(open(WL_PATH).read().split())
 
105
  def save_json(path, data):
106
  json.dump(data, open(path, "w"))
107
 
 
108
  @app.route("/", methods=["GET"])
109
  def health():
110
  return "LCM Image API Running", 200
 
115
  key = "sk-" + secrets.token_hex(16)
116
  with open(WL_PATH, "a") as f:
117
  f.write(key + "\n")
118
+
119
  limits = load_json(LIMITS_PATH)
120
  limits[key] = "unlimited" if data.get("unlimited") else int(data.get("limit", DEFAULT_LIMIT))
121
  save_json(LIMITS_PATH, limits)
 
129
 
130
  data = request.get_json() or {}
131
  prompt = data.get("prompt", "").strip()
132
+ if not prompt:
133
+ return jsonify({"error": "Prompt required"}), 400
134
 
 
135
  limits = load_json(LIMITS_PATH)
136
  usage = load_json(USAGE_PATH)
137
  limit = limits.get(key, DEFAULT_LIMIT)
138
+
139
  from datetime import datetime
140
  month = datetime.now().strftime("%Y-%m")
141
  used = usage.get(key, {}).get(month, 0)
142
+
143
  if limit != "unlimited" and used >= limit:
144
  return jsonify({"error": "Monthly limit reached"}), 429
145
 
 
146
  try:
 
147
  image = pipe(
148
  prompt=prompt,
149
  num_inference_steps=4,
150
  guidance_scale=1.5
151
  ).images[0]
152
+
153
  usage.setdefault(key, {})[month] = used + 1
154
  save_json(USAGE_PATH, usage)
155
 
 
157
  image.save(buf, format="PNG")
158
  buf.seek(0)
159
  return send_file(buf, mimetype="image/png")
160
+
161
  except Exception as e:
 
162
  return jsonify({"error": str(e)}), 500
163
 
164
  if __name__ == "__main__":
 
166
  EOF
167
 
168
  # ---------------- Start ----------------
169
+ # printf (not echo): POSIX echo's handling of "\n" is implementation-defined — bash's
+ # builtin would write the escape literally, collapsing the script to a single "#!" line
+ # so python3 would never run. printf expands "\n" under every /bin/sh.
+ RUN printf '#!/bin/bash\npython3 /app.py\n' > /start.sh && chmod +x /start.sh
 
 
 
 
 
170
 
171
  EXPOSE 7860
172
+ ENTRYPOINT ["/bin/bash", "/start.sh"]