guydffdsdsfd committed on
Commit
622a459
·
verified ·
1 Parent(s): 0bd5253

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +107 -65
Dockerfile CHANGED
@@ -1,89 +1,131 @@
1
- FROM python:3.10-slim
2
-
3
- # System deps (torch + diffusers need these or they sulk)
4
- RUN apt-get update && apt-get install -y \
5
- git \
6
- libgl1 \
7
- libglib2.0-0 \
8
- && rm -rf /var/lib/apt/lists/*
9
-
10
- # Python deps
11
- RUN pip install --no-cache-dir \
12
- torch \
13
- torchvision \
14
- torchaudio \
15
- diffusers \
16
- transformers \
17
- accelerate \
18
- safetensors \
19
- flask \
20
- flask-cors \
21
- pillow
22
-
23
- # Environment
24
- ENV HOME=/home/sd
25
- ENV HF_HOME=/home/sd/.cache
26
- ENV TRANSFORMERS_CACHE=/home/sd/.cache
27
- ENV DIFFUSERS_CACHE=/home/sd/.cache
28
-
29
- # Writable dirs (HF Spaces is picky)
30
- RUN mkdir -p /home/sd && chmod -R 777 /home/sd
31
-
32
- # -------- Flask Stable Diffusion API --------
33
- RUN cat <<'EOF' > /app.py
34
- from flask import Flask, request, jsonify, send_file
35
  from flask_cors import CORS
36
- from diffusers import StableDiffusionPipeline
37
- import torch
38
- from io import BytesIO
39
- import os
40
 
41
  app = Flask(__name__)
42
  CORS(app)
43
 
44
- MODEL_ID = "runwayml/stable-diffusion-v1-5"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- pipe = StableDiffusionPipeline.from_pretrained(
47
- MODEL_ID,
48
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
49
- )
50
- pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
51
 
 
 
 
 
52
  @app.route("/", methods=["GET"])
53
  def health():
54
- return "Stable Diffusion API Running", 200
55
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @app.route("/api/generate", methods=["POST"])
57
- def generate():
58
- data = request.get_json()
59
- prompt = data.get("prompt", "")
60
- steps = int(data.get("steps", 25))
61
- guidance = float(data.get("guidance", 7.5))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- if not prompt:
64
- return jsonify({"error": "No prompt provided"}), 400
65
 
66
- image = pipe(
67
- prompt=prompt,
68
- num_inference_steps=steps,
69
- guidance_scale=guidance
70
- ).images[0]
71
 
72
- buf = BytesIO()
73
- image.save(buf, format="PNG")
74
- buf.seek(0)
75
 
76
- return send_file(buf, mimetype="image/png")
77
 
78
  if __name__ == "__main__":
79
  app.run(host="0.0.0.0", port=7860)
80
  EOF
81
 
82
- # -------- Startup Script --------
83
  RUN cat <<'EOF' > /start.sh
84
  #!/bin/bash
85
- echo "Starting Stable Diffusion API..."
86
- python3 /app.py
 
 
 
87
  EOF
88
 
89
  RUN chmod +x /start.sh
 
1
FROM ollama/ollama:latest

# ---------------- System + Python ----------------
# --no-install-recommends plus clearing the apt lists keeps the layer small;
# pip's --no-cache-dir avoids baking the download cache into the image.
RUN apt-get update && \
    apt-get install -y --no-install-recommends python3 python3-pip && \
    pip3 install --no-cache-dir flask flask-cors requests --break-system-packages && \
    rm -rf /var/lib/apt/lists/*

# ---------------- Env ----------------
# Ollama listens only on loopback; the guard (port 7860) is the public face.
ENV OLLAMA_HOST=127.0.0.1:11434
ENV OLLAMA_MODELS=/home/ollama/.ollama/models
ENV HOME=/home/ollama

# ---------------- Storage ----------------
# HF Spaces runs the container as an arbitrary non-root user, so the
# data directory must be world-writable.
RUN mkdir -p /home/ollama/.ollama && chmod -R 777 /home/ollama
14
+
15
+ # ---------------- Guard API ----------------
16
+ RUN cat <<'EOF' > /guard.py
17
from flask import Flask, request, Response, jsonify, stream_with_context
from flask_cors import CORS

import datetime
import json
import os
import secrets

import requests

app = Flask(__name__)
CORS(app)

# Persistent state lives under the writable home directory.
DB_PATH = "/home/ollama/usage.json"                # per-key monthly request counts
WL_PATH = "/home/ollama/whitelist.txt"             # newline-separated API keys
KEY_LIMITS_PATH = "/home/ollama/key_limits.json"   # per-key monthly limits
DEFAULT_LIMIT = 500

# ---------- Init ----------
os.makedirs("/home/ollama", exist_ok=True)

# Ensure the whitelist file exists so later reads never fail.
if not os.path.exists(WL_PATH):
    with open(WL_PATH, "w"):
        pass
34
+
35
def load_limits():
    """Load the per-key limit map from disk.

    Returns a dict mapping API key -> int monthly limit or the string
    "unlimited"; returns {} when the file is missing or unreadable.
    """
    if os.path.exists(KEY_LIMITS_PATH):
        try:
            # `with` closes the handle even on parse failure (the original
            # leaked the file object and used a bare `except:`).
            with open(KEY_LIMITS_PATH) as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            # Corrupt or unreadable limits file: fall back to defaults.
            pass
    return {}
42
 
43
def save_limits(data):
    """Persist the per-key limit map.

    Uses `with` so the handle is flushed and closed deterministically
    (the original left the file object open).
    """
    with open(KEY_LIMITS_PATH, "w") as f:
        json.dump(data, f)
 
 
 
45
 
46
def whitelist():
    """Return the set of valid API keys (whitespace-separated file)."""
    # Close the handle deterministically instead of leaking it.
    with open(WL_PATH) as f:
        return set(f.read().split())
48
+
49
# ---------- Health ----------
@app.route("/", methods=["GET"])
def health():
    """Liveness probe for the guard process itself (not for Ollama)."""
    body, status = "Ollama Guard Running", 200
    return body, status
53
+
54
@app.route("/api/tags", methods=["GET"])
def tags():
    """Pass-through to Ollama's model-listing endpoint."""
    try:
        # Bound the wait so a hung Ollama can't pin this worker forever.
        r = requests.get("http://127.0.0.1:11434/api/tags", timeout=10)
        return Response(r.content, r.status_code,
                        content_type=r.headers.get("Content-Type"))
    except requests.RequestException:
        # Narrowed from a bare `except:` (which also swallowed SystemExit /
        # KeyboardInterrupt); a connection error means Ollama isn't up yet.
        return jsonify({"error": "Ollama starting"}), 503
61
+
62
# ---------- Key Generator ----------
@app.route("/generate-key", methods=["POST"])
def gen_key():
    """Mint a new API key and record its monthly request limit.

    Body (JSON, optional): {"unlimited": bool, "limit": int}.
    NOTE(review): this endpoint itself is unauthenticated — anyone who can
    reach the service can mint keys. Confirm that is intentional.
    """
    # silent=True: malformed JSON / wrong content type falls back to {}
    # instead of Flask raising a 400/415 before we see the request.
    data = request.get_json(silent=True) or {}
    unlimited = data.get("unlimited", False)
    limit = data.get("limit", DEFAULT_LIMIT)

    # Reject non-numeric limits up front instead of 500-ing on int() below.
    if not unlimited:
        try:
            limit = int(limit)
        except (TypeError, ValueError):
            return jsonify({"error": "limit must be an integer"}), 400

    key = "sk-" + secrets.token_hex(16)

    with open(WL_PATH, "a") as f:
        f.write(key + "\n")

    limits = load_limits()
    limits[key] = "unlimited" if unlimited else limit
    save_limits(limits)

    return jsonify({
        "key": key,
        "limit": limits[key]
    })
82
+
83
# ---------- Proxy ----------
@app.route("/api/generate", methods=["POST"])
@app.route("/api/chat", methods=["POST"])
def proxy():
    """Authenticated, rate-limited pass-through to Ollama generate/chat.

    Requires a whitelisted key in the `x-api-key` header, enforces a
    per-key monthly request count, then streams Ollama's response back.
    """
    key = request.headers.get("x-api-key", "")
    if key not in whitelist():
        return jsonify({"error": "Unauthorized"}), 401

    limits = load_limits()
    limit = limits.get(key, DEFAULT_LIMIT)
    unlimited = (limit == "unlimited")

    # Usage is bucketed by calendar month ("YYYY-MM").
    now = datetime.datetime.now().strftime("%Y-%m")
    if os.path.exists(DB_PATH):
        # `with` closes the handle (the original leaked it).
        with open(DB_PATH) as f:
            usage = json.load(f)
    else:
        usage = {}
    used = usage.get(key, {}).get(now, 0)

    if not unlimited and used >= limit:
        return jsonify({"error": "Monthly limit reached"}), 429

    # request.path is one of the two routes above, so this cannot be used
    # to reach arbitrary Ollama endpoints.
    target = "http://127.0.0.1:11434" + request.path
    resp = requests.post(target, json=request.json, stream=True, timeout=300)

    if resp.status_code != 200:
        body = resp.text
        resp.close()  # release the pooled connection on the error path
        return jsonify({"error": body}), resp.status_code

    # NOTE(review): read-modify-write with no locking — concurrent requests
    # can drop counts. Acceptable single-worker; revisit if scaled out.
    usage.setdefault(key, {})[now] = used + 1
    with open(DB_PATH, "w") as f:
        json.dump(usage, f)

    def stream():
        # Relay Ollama's (possibly chunked/streaming) body as-is.
        for chunk in resp.iter_content(1024):
            if chunk:
                yield chunk

    return Response(stream_with_context(stream()),
                    content_type=resp.headers.get("Content-Type"))
116
 
117
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the port HF Spaces routes traffic to.
    app.run(host="0.0.0.0", port=7860)
119
  EOF
120
 
121
# ---------------- Start Script ----------------
RUN cat <<'EOF' > /start.sh
#!/bin/bash
ollama serve &
python3 /guard.py &
# Poll until the Ollama server actually answers rather than a blind
# `sleep 5` — on slow cold starts the pull raced the server and failed.
# Give up after ~60s and attempt the pull anyway.
for i in $(seq 1 60); do
    ollama list >/dev/null 2>&1 && break
    sleep 1
done
ollama pull llama3.2
wait
EOF

RUN chmod +x /start.sh