guydffdsdsfd commited on
Commit
fb45ad3
·
verified ·
1 Parent(s): b82afe8

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +11 -12
Dockerfile CHANGED
@@ -12,7 +12,7 @@ RUN pip install --no-cache-dir \
12
  torch \
13
  torchvision \
14
  torchaudio \
15
- diffusers \
16
  transformers \
17
  accelerate \
18
  safetensors \
@@ -29,12 +29,12 @@ ENV DIFFUSERS_CACHE=/home/sd/.cache
29
  # ---------------- Storage ----------------
30
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
31
 
32
- # ---------------- Image Guard API ----------------
33
  RUN cat <<'EOF' > /app.py
34
  from flask import Flask, request, jsonify, send_file
35
  from flask_cors import CORS
36
- from diffusers import StableDiffusionPipeline
37
- import torch, os, json, datetime, secrets
38
  from io import BytesIO
39
 
40
  app = Flask(__name__)
@@ -47,7 +47,7 @@ USAGE_PATH = f"{BASE}/usage.json"
47
  LIMITS_PATH = f"{BASE}/limits.json"
48
 
49
  DEFAULT_LIMIT = 500
50
- MODEL_ID = "stable-diffusion-v1-5"
51
 
52
  # -------- Init storage --------
53
  os.makedirs(BASE, exist_ok=True)
@@ -56,10 +56,11 @@ for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
56
  open(p, "w").write("{}" if p.endswith(".json") else "")
57
 
58
  # -------- Load model once --------
59
- pipe = StableDiffusionPipeline.from_pretrained(
60
  MODEL_ID,
61
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
62
  )
 
63
  pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
64
 
65
  # -------- Helpers --------
@@ -78,7 +79,7 @@ def save_json(path, data):
78
  # -------- Health --------
79
  @app.route("/", methods=["GET"])
80
  def health():
81
- return "Image API Running", 200
82
 
83
  # -------- Key generator --------
84
  @app.route("/generate-key", methods=["POST"])
@@ -96,10 +97,7 @@ def generate_key():
96
  limits[key] = "unlimited" if unlimited else int(limit)
97
  save_json(LIMITS_PATH, limits)
98
 
99
- return jsonify({
100
- "key": key,
101
- "limit": limits[key]
102
- })
103
 
104
  # -------- Image generation --------
105
  @app.route("/api/generate", methods=["POST"])
@@ -122,7 +120,8 @@ def generate():
122
  limit = limits.get(key, DEFAULT_LIMIT)
123
  unlimited = (limit == "unlimited")
124
 
125
- month = datetime.datetime.now().strftime("%Y-%m")
 
126
  used = usage.get(key, {}).get(month, 0)
127
 
128
  if not unlimited and used >= limit:
 
12
  torch \
13
  torchvision \
14
  torchaudio \
15
+ diffusers["torch"] \
16
  transformers \
17
  accelerate \
18
  safetensors \
 
29
  # ---------------- Storage ----------------
30
  RUN mkdir -p /home/sd && chmod -R 777 /home/sd
31
 
32
+ # ---------------- Flask API ----------------
33
  RUN cat <<'EOF' > /app.py
34
  from flask import Flask, request, jsonify, send_file
35
  from flask_cors import CORS
36
+ from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler
37
+ import torch, os, json, secrets
38
  from io import BytesIO
39
 
40
  app = Flask(__name__)
 
47
  LIMITS_PATH = f"{BASE}/limits.json"
48
 
49
  DEFAULT_LIMIT = 500
50
+ MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
51
 
52
  # -------- Init storage --------
53
  os.makedirs(BASE, exist_ok=True)
 
56
  open(p, "w").write("{}" if p.endswith(".json") else "")
57
 
58
  # -------- Load model once --------
59
+ pipe = StableDiffusionXLPipeline.from_pretrained(
60
  MODEL_ID,
61
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
62
  )
63
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
64
  pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
65
 
66
  # -------- Helpers --------
 
79
  # -------- Health --------
80
  @app.route("/", methods=["GET"])
81
  def health():
82
+ return "Image XL API Running", 200
83
 
84
  # -------- Key generator --------
85
  @app.route("/generate-key", methods=["POST"])
 
97
  limits[key] = "unlimited" if unlimited else int(limit)
98
  save_json(LIMITS_PATH, limits)
99
 
100
+ return jsonify({"key": key, "limit": limits[key]})
 
 
 
101
 
102
  # -------- Image generation --------
103
  @app.route("/api/generate", methods=["POST"])
 
120
  limit = limits.get(key, DEFAULT_LIMIT)
121
  unlimited = (limit == "unlimited")
122
 
123
+ from datetime import datetime
124
+ month = datetime.now().strftime("%Y-%m")
125
  used = usage.get(key, {}).get(month, 0)
126
 
127
  if not unlimited and used >= limit: