File size: 9,372 Bytes
64cd325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import hmac
import logging
import os
import sys
from functools import wraps

import pandas as pd
import torch
from flask import Flask, jsonify, request
from huggingface_hub import hf_hub_download

# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Import the Kronos model classes from the project's 'model' package.
# NOTE(review): nothing here modifies sys.path, so the 'model' package must
# already be importable (e.g. the server is launched from the directory that
# contains 'model/'). Confirm the intended launch directory.
try:
    from model.kronos import Kronos, KronosTokenizer, KronosPredictor
except ImportError as e:
    # The server cannot load or run any model without these classes, so stop
    # immediately with a non-zero exit status instead of starting the app.
    logging.error(f"Could not import from model.kronos: {e}")
    sys.exit(1)

# --- Globals ---
# The Flask application that every route in this module registers on.
app = Flask(__name__)

# Server state rebound by load_model_endpoint():
#   predictor         -- the active KronosPredictor, or None until a model loads.
#   model_name_global -- AVAILABLE_MODELS key of the most recently loaded model
#                        (initially "kronos-base"); it is also the default key
#                        used when a load request omits `model_key`.
predictor: "KronosPredictor | None" = None
model_name_global: str = "kronos-base"

# Bearer token enforced by require_api_key(); when unset, auth is skipped.
API_KEY: "str | None" = os.environ.get("KRONOS_API_KEY")

# Models this server may load, keyed by the public `model_key`. `model_id` and
# `tokenizer_id` are the identifiers handed to from_pretrained(); the whole
# mapping is served as-is by /api/available-models.
AVAILABLE_MODELS = {
    'kronos-mini': dict(
        name='Kronos-mini',
        model_id='NeoQuasar/Kronos-mini',
        tokenizer_id='NeoQuasar/Kronos-Tokenizer-2k',
        context_length=2048,
        params='4.1M',
        description='Lightweight model, suitable for fast prediction',
    ),
    'kronos-small': dict(
        name='Kronos-small',
        model_id='NeoQuasar/Kronos-small',
        tokenizer_id='NeoQuasar/Kronos-Tokenizer-base',
        context_length=512,
        params='24.7M',
        description='Small model, balanced performance and speed',
    ),
    'kronos-base': dict(
        name='Kronos-base',
        model_id='NeoQuasar/Kronos-base',
        tokenizer_id='NeoQuasar/Kronos-Tokenizer-base',
        context_length=512,
        params='102.3M',
        description='Base model, provides better prediction quality',
    ),
}

# --- Helper Functions ---

def download_model_from_hf(model_name, local_dir=".", filenames=("config.json", "pytorch_model.bin")):
    """Download selected files of a Hugging Face Hub repository into a directory.

    Args:
        model_name: Hub repository id, e.g. ``"NeoQuasar/Kronos-small"``.
        local_dir: Directory the files are written to.
        filenames: Repository files to fetch, in order. The default keeps the
            original ``config.json`` + ``pytorch_model.bin`` pair; pass other
            names (e.g. ``("config.json", "model.safetensors")``) for
            repositories that store their weights differently.

    Returns:
        True if every file was downloaded. False on the first failure; the
        error is logged, not raised.
    """
    logging.info(f"Downloading model '{model_name}' from Hugging Face Hub...")
    for filename in filenames:
        try:
            hf_hub_download(repo_id=model_name, filename=filename, local_dir=local_dir)
        except Exception as e:
            # Best-effort helper: report which file failed and signal via the
            # return value rather than propagating.
            logging.error(f"Failed to download model: could not fetch '{filename}' from '{model_name}': {e}")
            return False
    logging.info("Model downloaded successfully.")
    return True

# --- API Authentication ---

def require_api_key(f):
    """Decorator that protects a Flask view with a static bearer-token check.

    When the module-level API_KEY (read from the KRONOS_API_KEY environment
    variable) is empty or unset, authentication is skipped and a warning is
    logged on every request. Otherwise the request must carry
    ``Authorization: Bearer <key>``. A missing or malformed header, or a
    wrong key, yields a 401 JSON error and the view is not called.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # No key configured on the server: local/dev mode, skip auth.
        if not API_KEY:
            logging.warning("API key not set. Skipping authentication.")
            return f(*args, **kwargs)

        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Authorization header is missing or invalid. Use Bearer token.'}), 401

        # Take everything after the scheme, so "Bearer KEY junk" is not
        # accepted as KEY (split(' ')[1] would have done that).
        token = auth_header[len('Bearer '):].strip()
        # Constant-time comparison, so response timing does not reveal how much
        # of the key matched. Compare bytes: compare_digest raises TypeError on
        # non-ASCII str arguments.
        if not hmac.compare_digest(token.encode('utf-8'), API_KEY.encode('utf-8')):
            return jsonify({'error': 'Invalid API Key.'}), 401

        return f(*args, **kwargs)
    return decorated_function

# --- API Endpoints ---

def _proxies_from_env():
    """Return proxies from HTTP_PROXY/HTTPS_PROXY as a dict, or None if neither is set."""
    proxies = {
        "http": os.environ.get("HTTP_PROXY"),
        "https": os.environ.get("HTTPS_PROXY"),
    }
    # Drop unset/empty entries; an empty mapping collapses to None.
    return {scheme: url for scheme, url in proxies.items() if url} or None


@app.route('/api/load-model', methods=['POST'])
@require_api_key
def load_model_endpoint():
    """Load (or switch to) a Kronos model and tokenizer and make it the active predictor.

    JSON body (optional; an empty body uses the defaults):
        model_key:    key of AVAILABLE_MODELS to load. Defaults to the current
                      model_name_global.
        force_reload: when truthy, reload even if this model is already active.

    Returns:
        JSON status plus the model's catalogue entry on success.
        400 for an unparsable or non-object body, or an unknown model_key.
        500 if downloading or constructing the model fails; the previously
        loaded predictor, if any, stays active.
    """
    global predictor, model_name_global

    json_data = request.get_json(silent=True)
    if json_data is None:
        # silent=True returns None both for "no body" and for "unparsable
        # body"; only an absent body should fall back to the defaults.
        if request.get_data():
            return jsonify({'error': 'Request body must be valid JSON.'}), 400
        json_data = {}
    if not isinstance(json_data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    model_key = json_data.get('model_key', model_name_global)
    force_reload = json_data.get('force_reload', False)

    # The isinstance guard also keeps unhashable values (lists, objects) from
    # raising TypeError in the membership test.
    if not isinstance(model_key, str) or model_key not in AVAILABLE_MODELS:
        return jsonify({
            'error': "Invalid model_key. Please choose from the allowed models.",
            'allowed_models': list(AVAILABLE_MODELS.keys())
        }), 400

    if predictor and not force_reload and model_name_global == model_key:
        return jsonify({'status': 'Model already loaded.'})

    model_config = AVAILABLE_MODELS[model_key]
    try:
        model_id = model_config['model_id']
        tokenizer_id = model_config['tokenizer_id']

        logging.info(f"Attempting to load model: {model_id}")
        logging.info(f"Attempting to load tokenizer: {tokenizer_id}")

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {device}")

        proxies = _proxies_from_env()
        if proxies:
            # Log only the schemes: proxy URLs may embed credentials.
            logging.info(f"Using proxies for: {', '.join(sorted(proxies))}")

        model = Kronos.from_pretrained(model_id, proxies=proxies)
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_id, proxies=proxies)

        # Publish the new predictor only after both pieces loaded, so a failure
        # leaves the previously active model in service.
        predictor = KronosPredictor(model, tokenizer, device=device)
        model_name_global = model_key
    except Exception as e:
        logging.exception(f"Error loading model: {e}")
        return jsonify({'error': str(e)}), 500

    logging.info(f"Model '{model_config['name']}' loaded successfully.")
    return jsonify({
        'status': f"Model '{model_config['name']}' loaded successfully.",
        'model_info': model_config
    })

@app.route('/api/model-status', methods=['GET'])
def model_status():
    """Report whether a predictor is active and, if so, which catalogue entry it is."""
    if not predictor:
        return jsonify({'status': 'not_loaded'})

    # model_info is None if model_name_global is somehow not a catalogue key.
    payload = {
        'status': 'loaded',
        'model_key': model_name_global,
        'model_info': AVAILABLE_MODELS.get(model_name_global),
    }
    return jsonify(payload)

@app.route('/api/available-models', methods=['GET'])
def get_available_models():
    """Serve the AVAILABLE_MODELS catalogue, keyed by model_key, as JSON."""
    catalogue = AVAILABLE_MODELS
    return jsonify(catalogue)

# Binance K-line rows start with: open time (ms), open, high, low, close, volume.
# Only these six leading fields are fed to the model; trailing fields are ignored.
_MODEL_INPUT_COLUMNS = ['timestamp', 'open', 'high', 'low', 'close', 'volume']


def _parse_pred_len(params):
    """Return the positive-integer prediction horizon from *params* (default 120).

    Raises ValueError if the value is not a positive integer.
    """
    pred_len = params.get('pred_len', 120)
    # bool is an int subclass; reject it so JSON `true` is not read as 1.
    if isinstance(pred_len, bool) or not isinstance(pred_len, int) or pred_len <= 0:
        raise ValueError('"pred_len" must be a positive integer.')
    return pred_len


def _k_lines_to_frame(k_lines):
    """Convert Binance-style K-line rows into the model's six-column input DataFrame.

    Each row must have at least six fields; extra fields are dropped.
    Raises ValueError (or TypeError) on malformed rows or non-numeric values.
    """
    if not isinstance(k_lines, list) or not k_lines:
        raise ValueError('"k_lines" must be a non-empty list of K-line rows.')

    width = len(_MODEL_INPUT_COLUMNS)
    rows = []
    for idx, line in enumerate(k_lines):
        if not isinstance(line, (list, tuple)) or len(line) < width:
            raise ValueError(f'K-line row {idx} must be a list with at least {width} fields.')
        rows.append(line[:width])

    df = pd.DataFrame(rows, columns=_MODEL_INPUT_COLUMNS)
    # to_numeric first, so millisecond timestamps sent as strings are accepted too.
    df['timestamp'] = pd.to_datetime(pd.to_numeric(df['timestamp']), unit='ms')
    for col in _MODEL_INPUT_COLUMNS[1:]:
        df[col] = pd.to_numeric(df[col])
    return df


def _future_timestamps(x_timestamp, pred_len):
    """Extrapolate *pred_len* timestamps after the last entry of *x_timestamp*.

    The step is the gap between the last two observations, or 1 minute when
    there is only one. Raises ValueError when that gap is not strictly
    positive (duplicate, descending, or missing timestamps).
    """
    if len(x_timestamp) > 1:
        interval = x_timestamp.iloc[-1] - x_timestamp.iloc[-2]
    else:
        # A single data point gives no spacing information; assume 1 minute.
        interval = pd.Timedelta(minutes=1)

    if pd.isna(interval) or interval <= pd.Timedelta(0):
        raise ValueError('K-line timestamps must be strictly increasing.')

    future = pd.date_range(
        start=x_timestamp.iloc[-1] + interval,
        periods=pred_len,
        freq=interval
    )
    # Return a Series (not a DatetimeIndex) so the predictor can use the `.dt` accessor.
    return pd.Series(future, name='timestamp')


@app.route('/api/predict_from_data', methods=['POST'])
@require_api_key
def predict_from_data():
    """Forecast future K-lines from raw K-line rows sent in the request body.

    Expected JSON body:
        k_lines:           non-empty list of Binance-style K-line rows. Only the
                           first six fields (open time in ms, open, high, low,
                           close, volume) are used.
        prediction_params: optional object. ``pred_len`` (positive int,
                           default 120) is the number of future steps.

    Returns:
        400 if no model is loaded or the input is malformed.
        500 if the predictor itself fails.
        Otherwise the forecast rows as JSON records.
    """
    if not predictor:
        return jsonify({'error': 'Model not loaded. Please call /api/load-model first.'}), 400

    # silent=True: a missing or non-JSON body becomes None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'k_lines' not in data:
        return jsonify({'error': 'Missing "k_lines" in request body.'}), 400

    params = data.get('prediction_params') or {}
    if not isinstance(params, dict):
        return jsonify({'error': '"prediction_params" must be a JSON object.'}), 400

    # Input validation is kept separate from prediction, so client errors map
    # to 400 and model failures map to 500.
    try:
        pred_len = _parse_pred_len(params)
        df_model_input = _k_lines_to_frame(data['k_lines'])
        x_timestamp = df_model_input['timestamp']
        y_timestamp = _future_timestamps(x_timestamp, pred_len)
    except (ValueError, TypeError) as e:
        return jsonify({'error': f'Invalid input data: {e}'}), 400

    logging.info(f"Making prediction with pred_len={pred_len} on data with shape {df_model_input.shape}")

    try:
        pred_df = predictor.predict(
            df=df_model_input,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            verbose=False  # Keep logs clean
        )
        # NOTE(review): to_dict(orient='records') drops the DataFrame index. If
        # the predictor indexes its output by y_timestamp, the forecast
        # timestamps are not included in the response. Confirm.
        prediction_results = pred_df.to_dict(orient='records')
    except Exception as e:
        logging.exception(f"Prediction failed: {e}")
        return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500

    return jsonify({
        'success': True,
        'prediction_params': {'pred_len': pred_len},
        'prediction_results': prediction_results
    })