guydffdsdsfd committed on
Commit
d298b41
·
verified ·
1 Parent(s): 751d4cf

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +140 -217
Dockerfile CHANGED
@@ -1,231 +1,154 @@
1
- from flask import Flask, request, jsonify, send_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from flask_cors import CORS
3
- import os, torch, io, time, json
4
- from diffusers import StableDiffusionPipeline
5
- import threading
6
 
7
  app = Flask(__name__)
 
8
 
9
- # CORS configuration
10
- CORS(app, resources={
11
- r"/*": {
12
- "origins": [
13
- "https://kaigpt.vercel.app",
14
- "https://kaigpt.vercel.app/chat",
15
- "http://localhost:3000",
16
- "*"
17
- ],
18
- "methods": ["GET", "POST", "OPTIONS"],
19
- "allow_headers": ["Content-Type", "Authorization", "x-api-key"]
20
- }
21
- })
22
-
23
- # Configuration
24
- WL_PATH = 'whitelist.txt'
25
- UNLIMITED_KEY = 'sk-ess4l0ri37'
26
- TRUSTED_DOMAINS = ["kaigpt.vercel.app", "localhost"]
27
-
28
- # Global progress tracking
29
- image_progress = {}
30
- progress_lock = threading.Lock()
31
-
32
- print('Loading Stable Diffusion v1.5...')
33
- try:
34
- pipe = StableDiffusionPipeline.from_pretrained(
35
- 'runwayml/stable-diffusion-v1-5',
36
- torch_dtype=torch.float32,
37
- safety_checker=None,
38
- requires_safety_checker=False
39
- ).to('cpu')
40
- print('✅ Stable Diffusion loaded successfully')
41
- except Exception as e:
42
- print(f'❌ Error loading Stable Diffusion: {e}')
43
- pipe = None
44
 
45
- def get_whitelist():
46
- """Get whitelisted API keys"""
47
- if not os.path.exists(WL_PATH):
48
- return {UNLIMITED_KEY}
49
- with open(WL_PATH, 'r') as f:
50
- return set(line.strip() for line in f.readlines() if line.strip())
51
-
52
- def is_trusted_origin():
53
- """Check if request comes from trusted origin"""
54
- origin = request.headers.get("Origin", "")
55
- referer = request.headers.get("Referer", "")
56
-
57
- for domain in TRUSTED_DOMAINS:
58
- if domain in origin or domain in referer:
59
- return True
60
- return False
61
-
62
- def update_progress(request_id, progress, status):
63
- """Update progress for a request"""
64
- with progress_lock:
65
- image_progress[request_id] = {
66
- 'progress': progress,
67
- 'status': status,
68
- 'timestamp': time.time()
69
- }
70
-
71
- def cleanup_old_progress():
72
- """Remove old progress entries"""
73
- with progress_lock:
74
- current_time = time.time()
75
- to_remove = []
76
- for req_id, data in image_progress.items():
77
- if current_time - data['timestamp'] > 300: # 5 minutes
78
- to_remove.append(req_id)
79
-
80
- for req_id in to_remove:
81
- del image_progress[req_id]
82
-
83
- @app.route('/api/txt2img', methods=['POST', 'OPTIONS'])
84
- def gen_img():
85
- """Generate image from text prompt"""
86
- if request.method == 'OPTIONS':
87
- return jsonify({'status': 'ok'}), 200
88
-
89
- # Check authorization
90
- if not is_trusted_origin():
91
- api_key = request.headers.get('x-api-key') or request.json.get('api_key', '')
92
- if api_key not in get_whitelist():
93
- return jsonify({'error': 'Unauthorized', 'message': 'Invalid API key'}), 401
94
-
95
- if not pipe:
96
- return jsonify({'error': 'Model not loaded', 'message': 'Stable Diffusion is not available'}), 503
97
-
98
- data = request.get_json(force=True) or {}
99
- prompt = data.get('prompt', 'a beautiful landscape')
100
- steps = min(max(int(data.get('steps', 25)), 10), 50) # Clamp between 10-50
101
- request_id = data.get('request_id', f'img_{int(time.time())}_{hash(prompt) % 10000}')
102
-
103
- # Clean up old progress entries
104
- cleanup_old_progress()
105
-
106
- # Initialize progress
107
- update_progress(request_id, 0, 'Starting image generation...')
108
-
109
  try:
110
- # Define progress callback
111
- def progress_callback(step, timestep, latents):
112
- progress = int((step / steps) * 100)
113
- update_progress(request_id, progress, f'Step {step}/{steps}')
114
-
115
- # Generate image
116
- print(f'Generating image: "{prompt[:50]}..." ({steps} steps)')
117
-
118
- with torch.no_grad():
119
- image = pipe(
120
- prompt,
121
- num_inference_steps=steps,
122
- guidance_scale=7.5,
123
- callback=progress_callback,
124
- callback_steps=1
125
- ).images[0]
126
-
127
- # Convert to bytes
128
- img_io = io.BytesIO()
129
- image.save(img_io, 'PNG', quality=95)
130
- img_io.seek(0)
131
-
132
- # Mark as complete
133
- update_progress(request_id, 100, 'Complete!')
134
-
135
- # Return image
136
- return send_file(
137
- img_io,
138
- mimetype='image/png',
139
- as_attachment=False,
140
- download_name=f'generated_{int(time.time())}.png'
141
- )
142
 
 
 
 
 
143
  except Exception as e:
144
- print(f'Image generation error: {e}')
145
- update_progress(request_id, 0, f'Error: {str(e)}')
146
- return jsonify({
147
- 'error': 'Generation failed',
148
- 'message': str(e),
149
- 'request_id': request_id
150
- }), 500
151
-
152
- @app.route('/api/img_progress/<request_id>', methods=['GET'])
153
- def get_progress(request_id):
154
- """Get progress of image generation"""
155
- cleanup_old_progress()
156
-
157
- with progress_lock:
158
- progress_data = image_progress.get(request_id, {
159
- 'progress': 0,
160
- 'status': 'Not found or expired',
161
- 'timestamp': time.time()
162
- })
163
-
164
- return jsonify(progress_data)
165
 
166
- @app.route('/api/health', methods=['GET'])
 
167
  def health():
168
- """Health check endpoint"""
169
- status = {
170
- 'status': 'online' if pipe else 'offline',
171
- 'model': 'stable-diffusion-v1-5',
172
- 'loaded': pipe is not None,
173
- 'trusted_domains': TRUSTED_DOMAINS,
174
- 'timestamp': time.time()
175
- }
176
- return jsonify(status)
177
-
178
- @app.route('/api/whitelist/add', methods=['POST'])
179
- def add_to_whitelist():
180
- """Add API key to whitelist (admin only)"""
181
- data = request.get_json() or {}
182
- admin_key = data.get('admin_key', '')
183
- new_key = data.get('key', '').strip()
184
-
185
- # Simple admin check - in production use proper authentication
186
- if admin_key != 'admin123':
187
- return jsonify({'error': 'Invalid admin key'}), 403
188
-
189
- if not new_key:
190
- return jsonify({'error': 'No key provided'}), 400
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  try:
193
- with open(WL_PATH, 'a') as f:
194
- f.write(f"{new_key}\n")
195
 
196
- return jsonify({
197
- 'status': 'success',
198
- 'message': f'Key added to whitelist',
199
- 'total_keys': len(get_whitelist())
200
- }), 200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  except Exception as e:
202
- return jsonify({'error': str(e)}), 500
203
-
204
- @app.route('/api/config', methods=['GET'])
205
- def get_config():
206
- """Get server configuration"""
207
- return jsonify({
208
- 'max_steps': 50,
209
- 'min_steps': 10,
210
- 'default_steps': 25,
211
- 'supported_sizes': ['512x512', '768x768'],
212
- 'model': 'stable-diffusion-v1-5'
213
- })
214
-
215
- if __name__ == '__main__':
216
- # Create whitelist file if it doesn't exist
217
- if not os.path.exists(WL_PATH):
218
- with open(WL_PATH, 'w') as f:
219
- f.write(f"{UNLIMITED_KEY}\n")
220
- print(f'Created whitelist file with default key: {UNLIMITED_KEY}')
221
-
222
- print(f'Whitelisted keys: {get_whitelist()}')
223
- print(f'Trusted domains: {TRUSTED_DOMAINS}')
224
- print(f'Server starting on port 7860...')
225
-
226
- app.run(
227
- host='0.0.0.0',
228
- port=7860,
229
- debug=False,
230
- threaded=True
231
- )
 
1
FROM ollama/ollama:latest

# Install Python plus the guard-script dependencies.
# --break-system-packages is needed on PEP 668 "externally managed" Debian images.
RUN apt-get update && apt-get install -y python3 python3-pip && \
    pip3 install flask flask-cors requests --break-system-packages

# Environment: keep the daemon loopback-only and give it a writable HOME
# (HF Spaces containers run as an arbitrary non-root UID).
ENV OLLAMA_HOST=127.0.0.1:11434
ENV OLLAMA_MODELS=/home/ollama/.ollama/models
ENV HOME=/home/ollama

# World-writable so the runtime user can store models and usage counters.
RUN mkdir -p /home/ollama/.ollama && chmod -R 777 /home/ollama
16
# --- Flask guard script: auth + usage-limited proxy in front of Ollama ---
RUN cat <<'EOF' > /guard.py
from flask import Flask, request, Response, jsonify, stream_with_context
import requests
from flask_cors import CORS
import json, os, datetime, time, threading

app = Flask(__name__)
CORS(app)

# Persistent state lives under the writable home directory.
DB_PATH = "/home/ollama/usage.json"     # per-key monthly usage counters
WL_PATH = "/home/ollama/whitelist.txt"  # one API key per line
LIMIT = 500                             # monthly request cap per non-unlimited key
UNLIMITED_KEY = "sk-ess4l0ri37"         # NOTE(review): secret baked into the image — move to an env var

# Seed the whitelist on first boot so the service is usable immediately.
if not os.path.exists(WL_PATH):
    with open(WL_PATH, "w") as f:
        f.write(f"sk-admin-seed-99\nsk-ljlubs0boej\n{UNLIMITED_KEY}\n")
34
+
35
# SECURITY NOTE(review): this endpoint is unauthenticated — anyone who can
# reach the service can whitelist their own key. Gate it behind an admin
# secret before exposing it publicly.
@app.route("/whitelist", methods=["POST"])
def whitelist_key():
    """Append an API key to the whitelist file.

    Body: JSON ``{"key": "<api key>"}``.
    Returns 200 on success, 400 when no key is supplied, 500 on I/O errors.
    """
    try:
        # silent=True yields None on a missing/invalid JSON body instead of
        # raising inside the try (which previously surfaced as a 500).
        data = request.get_json(silent=True) or {}
        key = data.get("key", "").strip()
        if not key:
            return jsonify({"error": "No key provided"}), 400

        # Skip duplicates so repeated registrations don't grow the file.
        if key not in get_whitelist():
            with open(WL_PATH, "a") as f:
                f.write(f"{key}\n")
        return jsonify({"message": "Key whitelisted successfully"}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
# Liveness probe — HF Spaces expects port 7860 to answer promptly.
@app.route("/", methods=["GET"])
def health():
    """Plain-text liveness check for the proxy itself (not the daemon)."""
    return "Ollama Proxy is Running", 200
55
+
56
# Pass-through of Ollama's model listing; clients use it as a readiness check.
@app.route("/api/tags", methods=["GET"])
def tags():
    """Forward GET /api/tags to the local Ollama daemon.

    Returns the daemon's response verbatim, or 503 while it is unreachable.
    """
    try:
        # Timeout keeps the worker from hanging forever if the daemon wedges.
        resp = requests.get("http://127.0.0.1:11434/api/tags", timeout=10)
        return Response(resp.content, status=resp.status_code, content_type=resp.headers.get('Content-Type'))
    except requests.RequestException:
        # Narrowed from a bare `except:` (which also swallowed SystemExit /
        # KeyboardInterrupt); any network failure means Ollama isn't up yet.
        return jsonify({"error": "Ollama starting"}), 503
64
+
65
def get_whitelist():
    """Return the set of whitelisted API keys from WL_PATH.

    Blank lines are skipped: an empty string in the set would otherwise
    authenticate requests that send no x-api-key header at all (the proxy
    defaults the header to ""). Falls back to the unlimited key alone when
    the file is unreadable.
    """
    try:
        with open(WL_PATH, "r") as f:
            return {line.strip() for line in f if line.strip()}
    except OSError:
        # Narrowed from a bare `except:`; only file-access errors are expected.
        return {UNLIMITED_KEY}
71
+
72
+ @app.route("/api/generate", methods=["POST"])
73
+ @app.route("/api/chat", methods=["POST"])
74
+ def proxy():
75
+ user_key = request.headers.get("x-api-key", "")
 
76
 
77
+ # 1. Auth Check
78
+ if user_key not in get_whitelist():
79
+ return jsonify({"error": "Unauthorized: Key not registered"}), 401
80
+
81
+ # 2. Usage Check
82
+ is_unlimited = (user_key == UNLIMITED_KEY)
83
+ if not is_unlimited:
84
+ now = datetime.datetime.now()
85
+ month_key = now.strftime("%Y-%m")
86
+ usage = {}
87
+ if os.path.exists(DB_PATH):
88
+ try:
89
+ with open(DB_PATH, "r") as f:
90
+ usage = json.load(f)
91
+ except:
92
+ usage = {}
93
+ key_usage = usage.get(user_key, {}).get(month_key, 0)
94
+ if key_usage >= LIMIT:
95
+ return jsonify({"error": f"Monthly limit of {LIMIT} reached"}), 429
96
+
97
+ # 3. Proxy to Ollama
98
  try:
99
+ target_url = "http://127.0.0.1:11434" + request.path
 
100
 
101
+ resp = requests.post(target_url, json=request.json, stream=True, timeout=300)
102
+
103
+ if resp.status_code == 404:
104
+ return jsonify({"error": "Model is loading (First run takes ~2 mins). Please wait."}), 503
105
+
106
+ if resp.status_code != 200:
107
+ return jsonify({"error": f"Ollama Error: {resp.text}"}), resp.status_code
108
+
109
+ # Log usage
110
+ if not is_unlimited:
111
+ if user_key not in usage: usage[user_key] = {}
112
+ usage[user_key][month_key] = key_usage + 1
113
+ with open(DB_PATH, "w") as f:
114
+ json.dump(usage, f)
115
+
116
+ # Stream response
117
+ def generate():
118
+ for chunk in resp.iter_content(chunk_size=1024):
119
+ if chunk: yield chunk
120
+
121
+ return Response(stream_with_context(generate()), content_type=resp.headers.get('Content-Type'))
122
+
123
+ except requests.exceptions.ConnectionError:
124
+ return jsonify({"error": "Ollama is starting up. Please wait..."}), 503
125
  except Exception as e:
126
+ return jsonify({"error": f"Proxy Error: {str(e)}"}), 500
127
+
128
if __name__ == "__main__":
    # Bind all interfaces on the HF Spaces port.
    app.run(host="0.0.0.0", port=7860)
EOF
131
+
132
# --- Startup script: daemon + guard, then a readiness-gated model pull ---
RUN cat <<'EOF' > /start.sh
#!/bin/bash
# Start Ollama in the background
ollama serve &

# Start the Python guard (opens port 7860 immediately for the HF health check)
python3 /guard.py &

# Wait until the Ollama API actually answers instead of a fixed `sleep 5`
# (a slow daemon start previously made the pull fail). Bounded to ~60s so a
# wedged daemon cannot hang boot forever.
for i in $(seq 1 60); do
    ollama list >/dev/null 2>&1 && break
    sleep 1
done
echo "Starting Model Pull..."
ollama pull llama3.2:3b
echo "Model Pull Complete."

# Keep container running
wait
EOF

RUN chmod +x /start.sh

# --- Entrypoint ---
ENTRYPOINT ["/bin/bash", "/start.sh"]