guydffdsdsfd committed on
Commit
1d9f31c
·
verified ·
1 Parent(s): 7521234

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +32 -44
Dockerfile CHANGED
@@ -8,7 +8,6 @@ RUN apt-get update && apt-get install -y \
8
  && rm -rf /var/lib/apt/lists/*
9
 
10
  # ---------------- Python deps ----------------
11
- # We upgrade these to ensure LCM support works
12
  RUN pip install --no-cache-dir --upgrade \
13
  torch \
14
  torchvision \
@@ -24,10 +23,12 @@ RUN pip install --no-cache-dir --upgrade \
24
  # ---------------- Env ----------------
25
  ENV HOME=/home/sd
26
  ENV HF_HOME=/home/sd/.cache
 
 
27
  # ---------------- Storage ----------------
28
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
29
 
30
- # ---------------- Python Application (Written directly to file) ----------------
31
  RUN cat <<'EOF' > /app.py
32
  from flask import Flask, request, jsonify, send_file
33
  from flask_cors import CORS
@@ -45,9 +46,6 @@ USAGE_PATH = f"{BASE}/usage.json"
45
  LIMITS_PATH = f"{BASE}/limits.json"
46
 
47
  DEFAULT_LIMIT = 500
48
-
49
- # -------- Model Config --------
50
- # LCM Dreamshaper is SD1.5 based (small) and needs only 4-8 steps (fast)
51
  MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
52
 
53
  # -------- Init storage --------
@@ -56,16 +54,22 @@ for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
56
  if not os.path.exists(p):
57
  open(p, "w").write("{}" if p.endswith(".json") else "")
58
 
59
- # -------- Load model once --------
60
  print(f"Loading {MODEL_ID}...")
61
  pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
62
-
63
- # Ensure we use the LCM Scheduler (Fixes the IndexError crash)
64
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
65
 
66
- # Detect hardware (CPU vs CUDA)
67
  device = "cuda" if torch.cuda.is_available() else "cpu"
68
  pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
69
  print(f"Model loaded on {device}")
70
 
71
  # -------- Helpers --------
@@ -84,30 +88,23 @@ def load_json(path):
84
  def save_json(path, data):
85
  json.dump(data, open(path, "w"))
86
 
87
- # -------- Health --------
88
  @app.route("/", methods=["GET"])
89
  def health():
90
  return "LCM Image API Running", 200
91
 
92
- # -------- Key generator --------
93
  @app.route("/generate-key", methods=["POST"])
94
  def generate_key():
95
  data = request.get_json() or {}
96
- unlimited = data.get("unlimited", False)
97
- limit = data.get("limit", DEFAULT_LIMIT)
98
-
99
  key = "sk-" + secrets.token_hex(16)
100
-
101
  with open(WL_PATH, "a") as f:
102
  f.write(key + "\n")
103
-
104
  limits = load_json(LIMITS_PATH)
105
- limits[key] = "unlimited" if unlimited else int(limit)
106
  save_json(LIMITS_PATH, limits)
107
-
108
  return jsonify({"key": key, "limit": limits[key]})
109
 
110
- # -------- Image generation --------
111
  @app.route("/api/generate", methods=["POST"])
112
  def generate():
113
  key = request.headers.get("x-api-key", "")
@@ -116,49 +113,40 @@ def generate():
116
 
117
  data = request.get_json() or {}
118
  prompt = data.get("prompt", "").strip()
 
119
 
120
- # LCM Optimization:
121
- # LCM works best between 4 and 8 steps.
122
- # We ignore the user's requested 'steps' to ensure speed and stability.
123
- steps = 4
124
- guidance = 2.0
125
-
126
- if not prompt:
127
- return jsonify({"error": "Prompt required"}), 400
128
-
129
- # Rate Limiting Logic
130
  limits = load_json(LIMITS_PATH)
131
  usage = load_json(USAGE_PATH)
132
  limit = limits.get(key, DEFAULT_LIMIT)
133
- unlimited = (limit == "unlimited")
134
  from datetime import datetime
135
  month = datetime.now().strftime("%Y-%m")
136
  used = usage.get(key, {}).get(month, 0)
137
 
138
- if not unlimited and used >= limit:
139
  return jsonify({"error": "Monthly limit reached"}), 429
140
 
141
  # Generate
142
  try:
 
143
  image = pipe(
144
  prompt=prompt,
145
- num_inference_steps=steps,
146
- guidance_scale=guidance
147
  ).images[0]
 
 
 
 
 
 
 
 
148
  except Exception as e:
149
- print(f"Generation Error: {e}")
150
  return jsonify({"error": str(e)}), 500
151
 
152
- # Save Usage
153
- usage.setdefault(key, {})[month] = used + 1
154
- save_json(USAGE_PATH, usage)
155
-
156
- # Return Image
157
- buf = BytesIO()
158
- image.save(buf, format="PNG")
159
- buf.seek(0)
160
- return send_file(buf, mimetype="image/png")
161
-
162
  if __name__ == "__main__":
163
  app.run(host="0.0.0.0", port=7860)
164
  EOF
 
8
  && rm -rf /var/lib/apt/lists/*
9
 
10
  # ---------------- Python deps ----------------
 
11
  RUN pip install --no-cache-dir --upgrade \
12
  torch \
13
  torchvision \
 
23
  # ---------------- Env ----------------
24
  ENV HOME=/home/sd
25
  ENV HF_HOME=/home/sd/.cache
26
+ # Limit threads to prevent CPU choking
27
+ ENV OMP_NUM_THREADS=1
28
  # ---------------- Storage ----------------
29
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
30
 
31
+ # ---------------- Python Application ----------------
32
  RUN cat <<'EOF' > /app.py
33
  from flask import Flask, request, jsonify, send_file
34
  from flask_cors import CORS
 
46
  LIMITS_PATH = f"{BASE}/limits.json"
47
 
48
  DEFAULT_LIMIT = 500
 
 
 
49
  MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
50
 
51
  # -------- Init storage --------
 
54
  if not os.path.exists(p):
55
  open(p, "w").write("{}" if p.endswith(".json") else "")
56
 
57
+ # -------- Load model --------
58
  print(f"Loading {MODEL_ID}...")
59
  pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
 
 
60
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
61
 
 
62
  device = "cuda" if torch.cuda.is_available() else "cpu"
63
  pipe = pipe.to(device)
64
+
65
+ # -------- CRITICAL MEMORY FIXES --------
66
+ if device == "cpu":
67
+ # Slices attention computation into chunks (Saves ~2GB RAM)
68
+ pipe.enable_attention_slicing()
69
+ # Slices VAE decoding (Saves ~1GB RAM)
70
+ pipe.enable_vae_slicing()
71
+ print("Memory optimizations enabled for CPU.")
72
+
73
  print(f"Model loaded on {device}")
74
 
75
  # -------- Helpers --------
 
88
  def save_json(path, data):
89
  json.dump(data, open(path, "w"))
90
 
91
+ # -------- Routes --------
92
  @app.route("/", methods=["GET"])
93
  def health():
94
  return "LCM Image API Running", 200
95
 
 
96
  @app.route("/generate-key", methods=["POST"])
97
  def generate_key():
98
  data = request.get_json() or {}
 
 
 
99
  key = "sk-" + secrets.token_hex(16)
 
100
  with open(WL_PATH, "a") as f:
101
  f.write(key + "\n")
102
+
103
  limits = load_json(LIMITS_PATH)
104
+ limits[key] = "unlimited" if data.get("unlimited") else int(data.get("limit", DEFAULT_LIMIT))
105
  save_json(LIMITS_PATH, limits)
 
106
  return jsonify({"key": key, "limit": limits[key]})
107
 
 
108
  @app.route("/api/generate", methods=["POST"])
109
  def generate():
110
  key = request.headers.get("x-api-key", "")
 
113
 
114
  data = request.get_json() or {}
115
  prompt = data.get("prompt", "").strip()
116
+ if not prompt: return jsonify({"error": "Prompt required"}), 400
117
 
118
+ # Rate Limiting
 
 
 
 
 
 
 
 
 
119
  limits = load_json(LIMITS_PATH)
120
  usage = load_json(USAGE_PATH)
121
  limit = limits.get(key, DEFAULT_LIMIT)
122
+
123
  from datetime import datetime
124
  month = datetime.now().strftime("%Y-%m")
125
  used = usage.get(key, {}).get(month, 0)
126
 
127
+ if limit != "unlimited" and used >= limit:
128
  return jsonify({"error": "Monthly limit reached"}), 429
129
 
130
  # Generate
131
  try:
132
+ # Hardcoded for stability
133
  image = pipe(
134
  prompt=prompt,
135
+ num_inference_steps=4,
136
+ guidance_scale=1.5
137
  ).images[0]
138
+
139
+ usage.setdefault(key, {})[month] = used + 1
140
+ save_json(USAGE_PATH, usage)
141
+
142
+ buf = BytesIO()
143
+ image.save(buf, format="PNG")
144
+ buf.seek(0)
145
+ return send_file(buf, mimetype="image/png")
146
  except Exception as e:
147
+ print(f"Error: {e}")
148
  return jsonify({"error": str(e)}), 500
149
 
 
 
 
 
 
 
 
 
 
 
150
  if __name__ == "__main__":
151
  app.run(host="0.0.0.0", port=7860)
152
  EOF