guydffdsdsfd commited on
Commit
b82afe8
·
verified ·
1 Parent(s): 622a459

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +97 -73
Dockerfile CHANGED
@@ -1,67 +1,88 @@
1
- FROM ollama/ollama:latest
2
-
3
- # ---------------- System + Python ----------------
4
- RUN apt-get update && apt-get install -y python3 python3-pip && \
5
- pip3 install flask flask-cors requests --break-system-packages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # ---------------- Env ----------------
8
- ENV OLLAMA_HOST=127.0.0.1:11434
9
- ENV OLLAMA_MODELS=/home/ollama/.ollama/models
10
- ENV HOME=/home/ollama
 
11
 
12
  # ---------------- Storage ----------------
13
- RUN mkdir -p /home/ollama/.ollama && chmod -R 777 /home/ollama
14
 
15
- # ---------------- Guard API ----------------
16
- RUN cat <<'EOF' > /guard.py
17
- from flask import Flask, request, Response, jsonify, stream_with_context
18
  from flask_cors import CORS
19
- import requests, json, os, datetime, secrets
 
 
20
 
21
  app = Flask(__name__)
22
  CORS(app)
23
 
24
- DB_PATH = "/home/ollama/usage.json"
25
- WL_PATH = "/home/ollama/whitelist.txt"
26
- KEY_LIMITS_PATH = "/home/ollama/key_limits.json"
27
- DEFAULT_LIMIT = 500
28
-
29
- # ---------- Init ----------
30
- os.makedirs("/home/ollama", exist_ok=True)
31
-
32
- if not os.path.exists(WL_PATH):
33
- open(WL_PATH, "w").close()
34
-
35
- def load_limits():
36
- if os.path.exists(KEY_LIMITS_PATH):
37
- try:
38
- return json.load(open(KEY_LIMITS_PATH))
39
- except:
40
- pass
41
- return {}
42
-
43
- def save_limits(data):
44
- json.dump(data, open(KEY_LIMITS_PATH, "w"))
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def whitelist():
47
  return set(open(WL_PATH).read().split())
48
 
49
- # ---------- Health ----------
50
- @app.route("/", methods=["GET"])
51
- def health():
52
- return "Ollama Guard Running", 200
53
-
54
- @app.route("/api/tags", methods=["GET"])
55
- def tags():
56
  try:
57
- r = requests.get("http://127.0.0.1:11434/api/tags")
58
- return Response(r.content, r.status_code, content_type=r.headers.get("Content-Type"))
59
  except:
60
- return jsonify({"error": "Ollama starting"}), 503
61
 
62
- # ---------- Key Generator ----------
 
 
 
 
 
 
 
 
63
  @app.route("/generate-key", methods=["POST"])
64
- def gen_key():
65
  data = request.get_json() or {}
66
  unlimited = data.get("unlimited", False)
67
  limit = data.get("limit", DEFAULT_LIMIT)
@@ -71,65 +92,68 @@ def gen_key():
71
  with open(WL_PATH, "a") as f:
72
  f.write(key + "\n")
73
 
74
- limits = load_limits()
75
  limits[key] = "unlimited" if unlimited else int(limit)
76
- save_limits(limits)
77
 
78
  return jsonify({
79
  "key": key,
80
  "limit": limits[key]
81
  })
82
 
83
- # ---------- Proxy ----------
84
  @app.route("/api/generate", methods=["POST"])
85
- @app.route("/api/chat", methods=["POST"])
86
- def proxy():
87
  key = request.headers.get("x-api-key", "")
88
  if key not in whitelist():
89
  return jsonify({"error": "Unauthorized"}), 401
90
 
91
- limits = load_limits()
 
 
 
 
 
 
 
 
 
 
92
  limit = limits.get(key, DEFAULT_LIMIT)
93
  unlimited = (limit == "unlimited")
94
 
95
- now = datetime.datetime.now().strftime("%Y-%m")
96
- usage = json.load(open(DB_PATH)) if os.path.exists(DB_PATH) else {}
97
- used = usage.get(key, {}).get(now, 0)
98
 
99
  if not unlimited and used >= limit:
100
  return jsonify({"error": "Monthly limit reached"}), 429
101
 
102
- target = "http://127.0.0.1:11434" + request.path
103
- resp = requests.post(target, json=request.json, stream=True, timeout=300)
 
 
 
104
 
105
- if resp.status_code != 200:
106
- return jsonify({"error": resp.text}), resp.status_code
107
 
108
- usage.setdefault(key, {})[now] = used + 1
109
- json.dump(usage, open(DB_PATH, "w"))
 
110
 
111
- def stream():
112
- for c in resp.iter_content(1024):
113
- if c: yield c
114
-
115
- return Response(stream_with_context(stream()), content_type=resp.headers.get("Content-Type"))
116
 
117
  if __name__ == "__main__":
118
  app.run(host="0.0.0.0", port=7860)
119
  EOF
120
 
121
- # ---------------- Start Script ----------------
122
  RUN cat <<'EOF' > /start.sh
123
  #!/bin/bash
124
- ollama serve &
125
- python3 /guard.py &
126
- sleep 5
127
- ollama pull llama3.2
128
- wait
129
  EOF
130
 
131
  RUN chmod +x /start.sh
132
 
133
  EXPOSE 7860
134
-
135
  ENTRYPOINT ["/bin/bash", "/start.sh"]
 
1
+ FROM python:3.10-slim
2
+
3
+ # ---------------- System deps ----------------
4
+ RUN apt-get update && apt-get install -y \
5
+ git \
6
+ libgl1 \
7
+ libglib2.0-0 \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # ---------------- Python deps ----------------
11
+ RUN pip install --no-cache-dir \
12
+ torch \
13
+ torchvision \
14
+ torchaudio \
15
+ diffusers \
16
+ transformers \
17
+ accelerate \
18
+ safetensors \
19
+ flask \
20
+ flask-cors \
21
+ pillow
22
 
23
  # ---------------- Env ----------------
24
+ ENV HOME=/home/sd
25
+ ENV HF_HOME=/home/sd/.cache
26
+ ENV TRANSFORMERS_CACHE=/home/sd/.cache
27
+ ENV DIFFUSERS_CACHE=/home/sd/.cache
28
 
29
  # ---------------- Storage ----------------
30
+ RUN mkdir -p /home/sd && chmod -R 777 /home/sd
31
 
32
+ # ---------------- Image Guard API ----------------
33
+ RUN cat <<'EOF' > /app.py
34
+ from flask import Flask, request, jsonify, send_file
35
  from flask_cors import CORS
36
+ from diffusers import StableDiffusionPipeline
37
+ import torch, os, json, datetime, secrets
38
+ from io import BytesIO
39
 
40
  app = Flask(__name__)
41
  CORS(app)
42
 
43
+ # -------- Paths --------
44
+ BASE = "/home/sd"
45
+ WL_PATH = f"{BASE}/whitelist.txt"
46
+ USAGE_PATH = f"{BASE}/usage.json"
47
+ LIMITS_PATH = f"{BASE}/limits.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ DEFAULT_LIMIT = 500
50
+ MODEL_ID = "stable-diffusion-v1-5"
51
+
52
+ # -------- Init storage --------
53
+ os.makedirs(BASE, exist_ok=True)
54
+ for p in [WL_PATH, USAGE_PATH, LIMITS_PATH]:
55
+ if not os.path.exists(p):
56
+ open(p, "w").write("{}" if p.endswith(".json") else "")
57
+
58
+ # -------- Load model once --------
59
+ pipe = StableDiffusionPipeline.from_pretrained(
60
+ MODEL_ID,
61
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
62
+ )
63
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
64
+
65
+ # -------- Helpers --------
66
  def whitelist():
67
  return set(open(WL_PATH).read().split())
68
 
69
+ def load_json(path):
 
 
 
 
 
 
70
  try:
71
+ return json.load(open(path))
 
72
  except:
73
+ return {}
74
 
75
+ def save_json(path, data):
76
+ json.dump(data, open(path, "w"))
77
+
78
+ # -------- Health --------
79
+ @app.route("/", methods=["GET"])
80
+ def health():
81
+ return "Image API Running", 200
82
+
83
+ # -------- Key generator --------
84
  @app.route("/generate-key", methods=["POST"])
85
+ def generate_key():
86
  data = request.get_json() or {}
87
  unlimited = data.get("unlimited", False)
88
  limit = data.get("limit", DEFAULT_LIMIT)
 
92
  with open(WL_PATH, "a") as f:
93
  f.write(key + "\n")
94
 
95
+ limits = load_json(LIMITS_PATH)
96
  limits[key] = "unlimited" if unlimited else int(limit)
97
+ save_json(LIMITS_PATH, limits)
98
 
99
  return jsonify({
100
  "key": key,
101
  "limit": limits[key]
102
  })
103
 
104
+ # -------- Image generation --------
105
  @app.route("/api/generate", methods=["POST"])
106
+ def generate():
 
107
  key = request.headers.get("x-api-key", "")
108
  if key not in whitelist():
109
  return jsonify({"error": "Unauthorized"}), 401
110
 
111
+ data = request.get_json() or {}
112
+ prompt = data.get("prompt", "").strip()
113
+ steps = int(data.get("steps", 25))
114
+ guidance = float(data.get("guidance", 7.5))
115
+
116
+ if not prompt:
117
+ return jsonify({"error": "Prompt required"}), 400
118
+
119
+ limits = load_json(LIMITS_PATH)
120
+ usage = load_json(USAGE_PATH)
121
+
122
  limit = limits.get(key, DEFAULT_LIMIT)
123
  unlimited = (limit == "unlimited")
124
 
125
+ month = datetime.datetime.now().strftime("%Y-%m")
126
+ used = usage.get(key, {}).get(month, 0)
 
127
 
128
  if not unlimited and used >= limit:
129
  return jsonify({"error": "Monthly limit reached"}), 429
130
 
131
+ image = pipe(
132
+ prompt=prompt,
133
+ num_inference_steps=steps,
134
+ guidance_scale=guidance
135
+ ).images[0]
136
 
137
+ usage.setdefault(key, {})[month] = used + 1
138
+ save_json(USAGE_PATH, usage)
139
 
140
+ buf = BytesIO()
141
+ image.save(buf, format="PNG")
142
+ buf.seek(0)
143
 
144
+ return send_file(buf, mimetype="image/png")
 
 
 
 
145
 
146
  if __name__ == "__main__":
147
  app.run(host="0.0.0.0", port=7860)
148
  EOF
149
 
150
+ # ---------------- Start ----------------
151
  RUN cat <<'EOF' > /start.sh
152
  #!/bin/bash
153
+ python3 /app.py
 
 
 
 
154
  EOF
155
 
156
  RUN chmod +x /start.sh
157
 
158
  EXPOSE 7860
 
159
  ENTRYPOINT ["/bin/bash", "/start.sh"]