guydffdsdsfd committed on
Commit
751d4cf
·
verified ·
1 Parent(s): 53ce083

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +217 -140
Dockerfile CHANGED
@@ -1,154 +1,231 @@
1
- FROM ollama/ollama:latest
2
-
3
- # Install Python & Dependencies
4
- RUN apt-get update && apt-get install -y python3 python3-pip && \
5
- pip3 install flask flask-cors requests --break-system-packages
6
-
7
- # Set up environment variables
8
- ENV OLLAMA_HOST=127.0.0.1:11434
9
- ENV OLLAMA_MODELS=/home/ollama/.ollama/models
10
- ENV HOME=/home/ollama
11
-
12
- # Create writable directories
13
- RUN mkdir -p /home/ollama/.ollama && chmod -R 777 /home/ollama
14
-
15
- # --- COMPLETE Flask Guard Script (with whitelist endpoint) ---
16
- RUN cat <<'EOF' > /guard.py
17
- from flask import Flask, request, Response, jsonify, stream_with_context
18
- import requests
19
  from flask_cors import CORS
20
- import json, os, datetime, time, threading
 
 
21
 
22
  app = Flask(__name__)
23
- CORS(app)
24
-
25
- DB_PATH = "/home/ollama/usage.json"
26
- WL_PATH = "/home/ollama/whitelist.txt"
27
- LIMIT = 500
28
- UNLIMITED_KEY = "sk-ess4l0ri37"
29
 
30
- # Ensure whitelist exists
31
- if not os.path.exists(WL_PATH):
32
- with open(WL_PATH, "w") as f:
33
- f.write(f"sk-admin-seed-99\nsk-ljlubs0boej\n{UNLIMITED_KEY}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # CRITICAL: Whitelist Management Endpoint (was missing!)
36
- @app.route("/whitelist", methods=["POST"])
37
- def whitelist_key():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  try:
39
- data = request.get_json()
40
- key = data.get("key", "").strip()
41
- if not key:
42
- return jsonify({"error": "No key provided"}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # Add key to whitelist
45
- with open(WL_PATH, "a") as f:
46
- f.write(f"{key}\n")
47
- return jsonify({"message": "Key whitelisted successfully"}), 200
48
  except Exception as e:
49
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Health Check
52
- @app.route("/", methods=["GET"])
53
  def health():
54
- return "Ollama Proxy is Running", 200
55
-
56
- # API Tags endpoint for health checks
57
- @app.route("/api/tags", methods=["GET"])
58
- def tags():
59
- try:
60
- resp = requests.get("http://127.0.0.1:11434/api/tags")
61
- return Response(resp.content, status=resp.status_code, content_type=resp.headers.get('Content-Type'))
62
- except:
63
- return jsonify({"error": "Ollama starting"}), 503
64
-
65
- def get_whitelist():
66
- try:
67
- with open(WL_PATH, "r") as f:
68
- return set(line.strip() for line in f.readlines())
69
- except:
70
- return set([UNLIMITED_KEY])
71
-
72
- @app.route("/api/generate", methods=["POST"])
73
- @app.route("/api/chat", methods=["POST"])
74
- def proxy():
75
- user_key = request.headers.get("x-api-key", "")
 
76
 
77
- # 1. Auth Check
78
- if user_key not in get_whitelist():
79
- return jsonify({"error": "Unauthorized: Key not registered"}), 401
80
-
81
- # 2. Usage Check
82
- is_unlimited = (user_key == UNLIMITED_KEY)
83
- if not is_unlimited:
84
- now = datetime.datetime.now()
85
- month_key = now.strftime("%Y-%m")
86
- usage = {}
87
- if os.path.exists(DB_PATH):
88
- try:
89
- with open(DB_PATH, "r") as f:
90
- usage = json.load(f)
91
- except:
92
- usage = {}
93
- key_usage = usage.get(user_key, {}).get(month_key, 0)
94
- if key_usage >= LIMIT:
95
- return jsonify({"error": f"Monthly limit of {LIMIT} reached"}), 429
96
-
97
- # 3. Proxy to Ollama
98
  try:
99
- target_url = "http://127.0.0.1:11434" + request.path
 
100
 
101
- resp = requests.post(target_url, json=request.json, stream=True, timeout=300)
102
-
103
- if resp.status_code == 404:
104
- return jsonify({"error": "Model is loading (First run takes ~2 mins). Please wait."}), 503
105
-
106
- if resp.status_code != 200:
107
- return jsonify({"error": f"Ollama Error: {resp.text}"}), resp.status_code
108
-
109
- # Log usage
110
- if not is_unlimited:
111
- if user_key not in usage: usage[user_key] = {}
112
- usage[user_key][month_key] = key_usage + 1
113
- with open(DB_PATH, "w") as f:
114
- json.dump(usage, f)
115
-
116
- # Stream response
117
- def generate():
118
- for chunk in resp.iter_content(chunk_size=1024):
119
- if chunk: yield chunk
120
-
121
- return Response(stream_with_context(generate()), content_type=resp.headers.get('Content-Type'))
122
-
123
- except requests.exceptions.ConnectionError:
124
- return jsonify({"error": "Ollama is starting up. Please wait..."}), 503
125
  except Exception as e:
126
- return jsonify({"error": f"Proxy Error: {str(e)}"}), 500
127
-
128
- if __name__ == "__main__":
129
- app.run(host="0.0.0.0", port=7860)
130
- EOF
131
-
132
- # --- Startup Script ---
133
- RUN cat <<'EOF' > /start.sh
134
- #!/bin/bash
135
- # Start Ollama in the background
136
- ollama serve &
137
-
138
- # Start the Python Guard (Opens Port 7860 immediately for HF)
139
- python3 /guard.py &
140
-
141
- # Wait for Ollama to wake up, then pull the model
142
- sleep 5
143
- echo "Starting Model Pull..."
144
- ollama pull llama3.2:3b
145
- echo "Model Pull Complete."
146
-
147
- # Keep container running
148
- wait
149
- EOF
150
-
151
- RUN chmod +x /start.sh
152
-
153
- # --- Entrypoint ---
154
- ENTRYPOINT ["/bin/bash", "/start.sh"]
 
 
1
+ from flask import Flask, request, jsonify, send_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from flask_cors import CORS
3
+ import os, torch, io, time, json
4
+ from diffusers import StableDiffusionPipeline
5
+ import threading
6
 
7
  app = Flask(__name__)
 
 
 
 
 
 
8
 
9
+ # CORS configuration
10
+ CORS(app, resources={
11
+ r"/*": {
12
+ "origins": [
13
+ "https://kaigpt.vercel.app",
14
+ "https://kaigpt.vercel.app/chat",
15
+ "http://localhost:3000",
16
+ "*"
17
+ ],
18
+ "methods": ["GET", "POST", "OPTIONS"],
19
+ "allow_headers": ["Content-Type", "Authorization", "x-api-key"]
20
+ }
21
+ })
22
+
23
+ # Configuration
24
+ WL_PATH = 'whitelist.txt'
25
+ UNLIMITED_KEY = 'sk-ess4l0ri37'
26
+ TRUSTED_DOMAINS = ["kaigpt.vercel.app", "localhost"]
27
+
28
+ # Global progress tracking
29
+ image_progress = {}
30
+ progress_lock = threading.Lock()
31
+
32
+ print('Loading Stable Diffusion v1.5...')
33
+ try:
34
+ pipe = StableDiffusionPipeline.from_pretrained(
35
+ 'runwayml/stable-diffusion-v1-5',
36
+ torch_dtype=torch.float32,
37
+ safety_checker=None,
38
+ requires_safety_checker=False
39
+ ).to('cpu')
40
+ print('✅ Stable Diffusion loaded successfully')
41
+ except Exception as e:
42
+ print(f'❌ Error loading Stable Diffusion: {e}')
43
+ pipe = None
44
 
45
+ def get_whitelist():
46
+ """Get whitelisted API keys"""
47
+ if not os.path.exists(WL_PATH):
48
+ return {UNLIMITED_KEY}
49
+ with open(WL_PATH, 'r') as f:
50
+ return set(line.strip() for line in f.readlines() if line.strip())
51
+
52
+ def is_trusted_origin():
53
+ """Check if request comes from trusted origin"""
54
+ origin = request.headers.get("Origin", "")
55
+ referer = request.headers.get("Referer", "")
56
+
57
+ for domain in TRUSTED_DOMAINS:
58
+ if domain in origin or domain in referer:
59
+ return True
60
+ return False
61
+
62
+ def update_progress(request_id, progress, status):
63
+ """Update progress for a request"""
64
+ with progress_lock:
65
+ image_progress[request_id] = {
66
+ 'progress': progress,
67
+ 'status': status,
68
+ 'timestamp': time.time()
69
+ }
70
+
71
+ def cleanup_old_progress():
72
+ """Remove old progress entries"""
73
+ with progress_lock:
74
+ current_time = time.time()
75
+ to_remove = []
76
+ for req_id, data in image_progress.items():
77
+ if current_time - data['timestamp'] > 300: # 5 minutes
78
+ to_remove.append(req_id)
79
+
80
+ for req_id in to_remove:
81
+ del image_progress[req_id]
82
+
83
+ @app.route('/api/txt2img', methods=['POST', 'OPTIONS'])
84
+ def gen_img():
85
+ """Generate image from text prompt"""
86
+ if request.method == 'OPTIONS':
87
+ return jsonify({'status': 'ok'}), 200
88
+
89
+ # Check authorization
90
+ if not is_trusted_origin():
91
+ api_key = request.headers.get('x-api-key') or request.json.get('api_key', '')
92
+ if api_key not in get_whitelist():
93
+ return jsonify({'error': 'Unauthorized', 'message': 'Invalid API key'}), 401
94
+
95
+ if not pipe:
96
+ return jsonify({'error': 'Model not loaded', 'message': 'Stable Diffusion is not available'}), 503
97
+
98
+ data = request.get_json(force=True) or {}
99
+ prompt = data.get('prompt', 'a beautiful landscape')
100
+ steps = min(max(int(data.get('steps', 25)), 10), 50) # Clamp between 10-50
101
+ request_id = data.get('request_id', f'img_{int(time.time())}_{hash(prompt) % 10000}')
102
+
103
+ # Clean up old progress entries
104
+ cleanup_old_progress()
105
+
106
+ # Initialize progress
107
+ update_progress(request_id, 0, 'Starting image generation...')
108
+
109
  try:
110
+ # Define progress callback
111
+ def progress_callback(step, timestep, latents):
112
+ progress = int((step / steps) * 100)
113
+ update_progress(request_id, progress, f'Step {step}/{steps}')
114
+
115
+ # Generate image
116
+ print(f'Generating image: "{prompt[:50]}..." ({steps} steps)')
117
+
118
+ with torch.no_grad():
119
+ image = pipe(
120
+ prompt,
121
+ num_inference_steps=steps,
122
+ guidance_scale=7.5,
123
+ callback=progress_callback,
124
+ callback_steps=1
125
+ ).images[0]
126
+
127
+ # Convert to bytes
128
+ img_io = io.BytesIO()
129
+ image.save(img_io, 'PNG', quality=95)
130
+ img_io.seek(0)
131
+
132
+ # Mark as complete
133
+ update_progress(request_id, 100, 'Complete!')
134
+
135
+ # Return image
136
+ return send_file(
137
+ img_io,
138
+ mimetype='image/png',
139
+ as_attachment=False,
140
+ download_name=f'generated_{int(time.time())}.png'
141
+ )
142
 
 
 
 
 
143
  except Exception as e:
144
+ print(f'Image generation error: {e}')
145
+ update_progress(request_id, 0, f'Error: {str(e)}')
146
+ return jsonify({
147
+ 'error': 'Generation failed',
148
+ 'message': str(e),
149
+ 'request_id': request_id
150
+ }), 500
151
+
152
+ @app.route('/api/img_progress/<request_id>', methods=['GET'])
153
+ def get_progress(request_id):
154
+ """Get progress of image generation"""
155
+ cleanup_old_progress()
156
+
157
+ with progress_lock:
158
+ progress_data = image_progress.get(request_id, {
159
+ 'progress': 0,
160
+ 'status': 'Not found or expired',
161
+ 'timestamp': time.time()
162
+ })
163
+
164
+ return jsonify(progress_data)
165
 
166
+ @app.route('/api/health', methods=['GET'])
 
167
  def health():
168
+ """Health check endpoint"""
169
+ status = {
170
+ 'status': 'online' if pipe else 'offline',
171
+ 'model': 'stable-diffusion-v1-5',
172
+ 'loaded': pipe is not None,
173
+ 'trusted_domains': TRUSTED_DOMAINS,
174
+ 'timestamp': time.time()
175
+ }
176
+ return jsonify(status)
177
+
178
+ @app.route('/api/whitelist/add', methods=['POST'])
179
+ def add_to_whitelist():
180
+ """Add API key to whitelist (admin only)"""
181
+ data = request.get_json() or {}
182
+ admin_key = data.get('admin_key', '')
183
+ new_key = data.get('key', '').strip()
184
+
185
+ # Simple admin check - in production use proper authentication
186
+ if admin_key != 'admin123':
187
+ return jsonify({'error': 'Invalid admin key'}), 403
188
+
189
+ if not new_key:
190
+ return jsonify({'error': 'No key provided'}), 400
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  try:
193
+ with open(WL_PATH, 'a') as f:
194
+ f.write(f"{new_key}\n")
195
 
196
+ return jsonify({
197
+ 'status': 'success',
198
+ 'message': f'Key added to whitelist',
199
+ 'total_keys': len(get_whitelist())
200
+ }), 200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  except Exception as e:
202
+ return jsonify({'error': str(e)}), 500
203
+
204
+ @app.route('/api/config', methods=['GET'])
205
+ def get_config():
206
+ """Get server configuration"""
207
+ return jsonify({
208
+ 'max_steps': 50,
209
+ 'min_steps': 10,
210
+ 'default_steps': 25,
211
+ 'supported_sizes': ['512x512', '768x768'],
212
+ 'model': 'stable-diffusion-v1-5'
213
+ })
214
+
215
+ if __name__ == '__main__':
216
+ # Create whitelist file if it doesn't exist
217
+ if not os.path.exists(WL_PATH):
218
+ with open(WL_PATH, 'w') as f:
219
+ f.write(f"{UNLIMITED_KEY}\n")
220
+ print(f'Created whitelist file with default key: {UNLIMITED_KEY}')
221
+
222
+ print(f'Whitelisted keys: {get_whitelist()}')
223
+ print(f'Trusted domains: {TRUSTED_DOMAINS}')
224
+ print(f'Server starting on port 7860...')
225
+
226
+ app.run(
227
+ host='0.0.0.0',
228
+ port=7860,
229
+ debug=False,
230
+ threaded=True
231
+ )