File size: 4,710 Bytes
d298b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e5ee5f
d298b41
1e5ee5f
 
d298b41
ce3749d
d298b41
 
 
 
abf19aa
d298b41
 
 
 
 
 
 
 
abf19aa
d298b41
 
 
 
abf19aa
d298b41
 
 
 
abf19aa
d298b41
ce3749d
d298b41
 
abf19aa
d298b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf19aa
d298b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e5ee5f
d298b41
7db7a01
d298b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8974663
d298b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b82988
d298b41
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Base image bundles the Ollama server binary.
# NOTE(review): `latest` is unpinned — builds are not reproducible; consider a version tag.
FROM ollama/ollama:latest

# Install Python & Dependencies
# --break-system-packages is required on Debian-based images with PEP 668 managed envs.
RUN apt-get update && apt-get install -y python3 python3-pip && \
    pip3 install flask flask-cors requests --break-system-packages

# Set up environment variables
# Ollama listens on loopback only; the Flask guard below is the sole public entrypoint.
ENV OLLAMA_HOST=127.0.0.1:11434
ENV OLLAMA_MODELS=/home/ollama/.ollama/models
ENV HOME=/home/ollama

# Create writable directories
# chmod 777 — presumably so an arbitrary non-root runtime UID (e.g. HF Spaces) can write; confirm.
RUN mkdir -p /home/ollama/.ollama && chmod -R 777 /home/ollama

# --- COMPLETE Flask Guard Script (with whitelist endpoint) ---
RUN cat <<'EOF' > /guard.py
from flask import Flask, request, Response, jsonify, stream_with_context
import requests
from flask_cors import CORS
import json, os, datetime, time, threading

app = Flask(__name__)
# Allow cross-origin calls from any browser frontend.
CORS(app)

# Per-key monthly usage counters, persisted as JSON.
DB_PATH = "/home/ollama/usage.json"
# Newline-separated list of authorized API keys.
WL_PATH = "/home/ollama/whitelist.txt"
# Requests allowed per key per calendar month.
LIMIT = 500
# This key bypasses the monthly limit entirely.
# NOTE(review): secrets hard-coded in the image are visible to anyone with the Dockerfile.
UNLIMITED_KEY = "sk-ess4l0ri37"

# Ensure whitelist exists
# Seed the whitelist on first boot with two preset keys plus the unlimited key.
if not os.path.exists(WL_PATH):
    with open(WL_PATH, "w") as f:
        f.write(f"sk-admin-seed-99\nsk-ljlubs0boej\n{UNLIMITED_KEY}\n")

# CRITICAL: Whitelist Management Endpoint (was missing!)
# SECURITY NOTE(review): this endpoint is unauthenticated — anyone who can
# reach the proxy can mint themselves a valid key. Consider requiring an
# admin key before deploying publicly.
@app.route("/whitelist", methods=["POST"])
def whitelist_key():
    """Append a key to the whitelist file.

    Expects JSON `{"key": "<api-key>"}`. Returns 400 when the body is
    missing, not JSON, or has an empty key (previously a non-JSON body
    made get_json() raise/return None and surfaced as an opaque 500).
    Re-posting an existing key is idempotent: it succeeds without
    writing a duplicate line.
    """
    try:
        # silent=True yields None instead of raising on malformed bodies.
        data = request.get_json(silent=True) or {}
        key = str(data.get("key", "")).strip()
        if not key:
            return jsonify({"error": "No key provided"}), 400

        # Add key to whitelist (skip the write if it is already present).
        if key not in get_whitelist():
            with open(WL_PATH, "a") as f:
                f.write(f"{key}\n")
        return jsonify({"message": "Key whitelisted successfully"}), 200
    except Exception as e:
        # Boundary handler: report unexpected failures as JSON.
        return jsonify({"error": str(e)}), 500

# Health Check
# Liveness probe for the Flask proxy process itself — does NOT check Ollama.
@app.route("/", methods=["GET"])
def health():
    """Return a plain-text 200 confirming the proxy is up."""
    return "Ollama Proxy is Running", 200

# API Tags endpoint for health checks
@app.route("/api/tags", methods=["GET"])
def tags():
    """Forward Ollama's model list (/api/tags); 503 while Ollama boots.

    Fixes: the upstream GET now carries a timeout (previously a hung
    Ollama would block this worker forever), and only network-level
    errors are mapped to 503 — the old bare `except` also swallowed
    programming errors.
    """
    try:
        resp = requests.get("http://127.0.0.1:11434/api/tags", timeout=10)
        return Response(resp.content, status=resp.status_code,
                        content_type=resp.headers.get('Content-Type'))
    except requests.RequestException:
        return jsonify({"error": "Ollama starting"}), 503

def get_whitelist():
    """Return the set of authorized API keys read from WL_PATH.

    Blank lines are now dropped: previously an empty line in the file put
    "" into the set, which matched requests sent with NO x-api-key header
    (the header default is "") — an authentication bypass. On read failure
    the set degrades to the unlimited key alone. The bare `except` is
    narrowed to OSError so logic errors are no longer silently swallowed.
    """
    try:
        with open(WL_PATH, "r") as f:
            return {line.strip() for line in f if line.strip()}
    except OSError:
        return {UNLIMITED_KEY}

@app.route("/api/generate", methods=["POST"])
@app.route("/api/chat", methods=["POST"])
def proxy():
    """Authenticated, rate-limited streaming proxy for Ollama generate/chat.

    Flow: (1) reject keys not in the whitelist (401); (2) cap non-unlimited
    keys at LIMIT requests per calendar month (429); (3) forward the JSON
    body to the local Ollama server and stream its reply back unchanged.
    """
    user_key = request.headers.get("x-api-key", "")

    # 1. Auth Check
    if user_key not in get_whitelist():
        return jsonify({"error": "Unauthorized: Key not registered"}), 401

    # 2. Usage Check (skipped entirely for the unlimited key).
    # Pre-initialize so the logging step below never touches unbound names.
    is_unlimited = (user_key == UNLIMITED_KEY)
    usage, month_key, key_usage = {}, "", 0
    if not is_unlimited:
        month_key = datetime.datetime.now().strftime("%Y-%m")
        if os.path.exists(DB_PATH):
            try:
                with open(DB_PATH, "r") as f:
                    usage = json.load(f)
            except (OSError, ValueError):
                # Unreadable/corrupt counter file: start fresh rather than fail.
                usage = {}
        key_usage = usage.get(user_key, {}).get(month_key, 0)
        if key_usage >= LIMIT:
            return jsonify({"error": f"Monthly limit of {LIMIT} reached"}), 429

    # Reject malformed bodies explicitly. Previously `request.json` raised
    # inside the try below and surfaced as an opaque 500.
    payload = request.get_json(silent=True)
    if payload is None:
        return jsonify({"error": "Invalid or missing JSON body"}), 400

    # 3. Proxy to Ollama
    try:
        target_url = "http://127.0.0.1:11434" + request.path

        resp = requests.post(target_url, json=payload, stream=True, timeout=300)

        if resp.status_code == 404:
            resp.close()  # release the connection; we are not streaming it
            return jsonify({"error": "Model is loading (First run takes ~2 mins). Please wait."}), 503

        if resp.status_code != 200:
            body = resp.text
            resp.close()
            return jsonify({"error": f"Ollama Error: {body}"}), resp.status_code

        # Log usage only for successful, metered requests.
        # NOTE(review): this read-modify-write of the JSON file is not guarded
        # by a lock, so concurrent requests can drop increments — confirm
        # whether best-effort counting is acceptable here.
        if not is_unlimited:
            usage.setdefault(user_key, {})[month_key] = key_usage + 1
            with open(DB_PATH, "w") as f:
                json.dump(usage, f)

        # Relay the upstream stream chunk-by-chunk to the client.
        def generate():
            for chunk in resp.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk

        return Response(stream_with_context(generate()), content_type=resp.headers.get('Content-Type'))

    except requests.exceptions.ConnectionError:
        # Ollama's port is not accepting connections yet.
        return jsonify({"error": "Ollama is starting up. Please wait..."}), 503
    except Exception as e:
        # Boundary handler: surface unexpected failures as JSON, not HTML.
        return jsonify({"error": f"Proxy Error: {str(e)}"}), 500

# Bind on all interfaces; 7860 is the port HF Spaces-style hosts probe — TODO confirm.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
EOF

# --- Startup Script ---
RUN cat <<'EOF' > /start.sh
#!/bin/bash
# Start Ollama in the background
ollama serve &

# Start the Python Guard (Opens Port 7860 immediately for HF)
python3 /guard.py &

# Wait until the Ollama API actually answers before pulling the model.
# A fixed `sleep 5` raced on slow hosts and let `ollama pull` fail; poll
# instead (`ollama list` succeeds only once the server is accepting requests).
until ollama list >/dev/null 2>&1; do
    sleep 1
done
echo "Starting Model Pull..."
ollama pull llama3.2:1b
echo "Model Pull Complete."

# Keep container running
wait
EOF

# Make the startup script executable inside the image.
RUN chmod +x /start.sh

# --- Entrypoint ---
# Launch via bash so the script runs even if the exec bit were lost.
ENTRYPOINT ["/bin/bash", "/start.sh"]