|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
from functools import wraps |
|
|
from flask import Flask, request, jsonify |
|
|
import torch |
|
|
import pandas as pd |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Root logger configuration: INFO and above, each record prefixed with a
# timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# The Kronos classes are required by every prediction route; if the
# `model.kronos` module cannot be imported, log why and exit with status 1
# instead of starting a server that cannot serve predictions.
try:
    from model.kronos import Kronos, KronosTokenizer, KronosPredictor
except ImportError as import_error:
    logging.error(f"Could not import from model.kronos: {import_error}")
    sys.exit(1)
|
|
|
|
|
|
|
|
# Flask application that the route decorators below register on.
app = Flask(__name__)

# Mutable server state, rebound by load_model_endpoint:
#   predictor         -- the active KronosPredictor, or None until a model is loaded.
#   model_name_global -- key into AVAILABLE_MODELS; it is the default model_key
#                        before any load and the loaded model's key afterwards.
predictor = None
model_name_global = "kronos-base"

# Bearer token checked by @require_api_key; when unset, authentication is skipped.
API_KEY = os.environ.get("KRONOS_API_KEY")

# Registry of loadable models, keyed by the `model_key` clients send.
# `model_id` / `tokenizer_id` are the identifiers handed to from_pretrained();
# each entry is also returned verbatim as `model_info` in API responses.
AVAILABLE_MODELS = {
    'kronos-mini': dict(
        name='Kronos-mini',
        model_id='NeoQuasar/Kronos-mini',
        tokenizer_id='NeoQuasar/Kronos-Tokenizer-2k',
        context_length=2048,
        params='4.1M',
        description='Lightweight model, suitable for fast prediction',
    ),
    'kronos-small': dict(
        name='Kronos-small',
        model_id='NeoQuasar/Kronos-small',
        tokenizer_id='NeoQuasar/Kronos-Tokenizer-base',
        context_length=512,
        params='24.7M',
        description='Small model, balanced performance and speed',
    ),
    'kronos-base': dict(
        name='Kronos-base',
        model_id='NeoQuasar/Kronos-base',
        tokenizer_id='NeoQuasar/Kronos-Tokenizer-base',
        context_length=512,
        params='102.3M',
        description='Base model, provides better prediction quality',
    ),
}
|
|
|
|
|
|
|
|
|
|
|
def download_model_from_hf(model_name, local_dir="."):
    """Fetch a model's `config.json` and `pytorch_model.bin` from the Hugging Face Hub.

    Args:
        model_name: Hub repository id to download from.
        local_dir: directory the files are written into (default: current dir).

    Returns:
        True if both files were downloaded; False if any download raised. The
        error is logged rather than propagated, and files are fetched in order,
        so a failure on the first file skips the second.
    """
    logging.info(f"Downloading model '{model_name}' from Hugging Face Hub...")
    required_files = ("config.json", "pytorch_model.bin")
    try:
        for filename in required_files:
            hf_hub_download(repo_id=model_name, filename=filename, local_dir=local_dir)
    except Exception as download_error:
        logging.error(f"Failed to download model: {download_error}")
        return False
    else:
        logging.info("Model downloaded successfully.")
        return True
|
|
|
|
|
|
|
|
|
|
|
def require_api_key(f):
    """Decorator that protects a Flask view with a Bearer-token API key.

    Behaviour depends on the module-level ``API_KEY``:

    * unset/empty -- authentication is skipped and a warning is logged on every
      request;
    * set -- the request must carry ``Authorization: Bearer <API_KEY>``. A
      missing or non-Bearer header, or a wrong token, yields a 401 JSON error
      and the wrapped view is not called.
    """
    import hmac

    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not API_KEY:
            logging.warning("API key not set. Skipping authentication.")
            return f(*args, **kwargs)

        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Authorization header is missing or invalid. Use Bearer token.'}), 401

        # Take everything after the scheme; split(' ')[1] would return '' for
        # "Bearer  tok" and silently truncate tokens containing spaces.
        token = auth_header[len('Bearer '):].strip()

        # Constant-time comparison, so response timing does not reveal how much
        # of the key matched. Compare bytes: compare_digest rejects non-ASCII str.
        if not hmac.compare_digest(token.encode('utf-8'), API_KEY.encode('utf-8')):
            return jsonify({'error': 'Invalid API Key.'}), 401

        return f(*args, **kwargs)

    return decorated_function
|
|
|
|
|
|
|
|
|
|
|
def _proxies_from_env():
    """Return the proxy mapping from HTTP_PROXY / HTTPS_PROXY, or None if neither is set."""
    proxies = {
        "http": os.environ.get("HTTP_PROXY"),
        "https": os.environ.get("HTTPS_PROXY"),
    }
    proxies = {scheme: url for scheme, url in proxies.items() if url}
    return proxies or None


@app.route('/api/load-model', methods=['POST'])
@require_api_key
def load_model_endpoint():
    """Load one of AVAILABLE_MODELS into the global predictor.

    Request JSON object (fields optional):
        model_key: key into AVAILABLE_MODELS; defaults to ``model_name_global``.
        force_reload: truthy to reload even if ``model_key`` is already loaded.

    Returns:
        200 ``{'status': 'Model already loaded.'}`` when nothing needs doing;
        200 ``{'status': ..., 'model_info': <registry entry>}`` after a load;
        400 if the body is not a JSON object or ``model_key`` is invalid;
        500 ``{'error': ...}`` if constructing the model/tokenizer/predictor fails.
    """
    global predictor, model_name_global

    # silent=True: a missing or malformed body yields None instead of raising,
    # so the client gets a JSON 400 rather than an HTML error or AttributeError.
    json_data = request.get_json(silent=True)
    if not isinstance(json_data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    model_key = json_data.get('model_key', model_name_global)
    force_reload = json_data.get('force_reload', False)

    # The isinstance check also stops unhashable JSON values (lists/objects)
    # from raising TypeError in the dict membership test.
    if not isinstance(model_key, str) or model_key not in AVAILABLE_MODELS:
        return jsonify({
            'error': "Invalid model_key. Please choose from the allowed models.",
            'allowed_models': list(AVAILABLE_MODELS.keys()),
        }), 400

    if predictor is not None and not force_reload and model_name_global == model_key:
        return jsonify({'status': 'Model already loaded.'})

    model_config = AVAILABLE_MODELS[model_key]
    model_id = model_config['model_id']
    tokenizer_id = model_config['tokenizer_id']
    logging.info(f"Attempting to load model: {model_id}")
    logging.info(f"Attempting to load tokenizer: {tokenizer_id}")

    try:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {device}")

        proxies = _proxies_from_env()
        if proxies:
            # Log only the schemes: proxy URLs may embed user:password credentials.
            logging.info(f"Using proxies for: {sorted(proxies)}")

        model = Kronos.from_pretrained(model_id, proxies=proxies)
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_id, proxies=proxies)
        new_predictor = KronosPredictor(model, tokenizer, device=device)
    except Exception as e:
        # logging.exception keeps the traceback that str(e) alone would lose.
        logging.exception(f"Error loading model: {e}")
        return jsonify({'error': str(e)}), 500

    # Swap the globals only after every step succeeded; a failed load leaves
    # the previously loaded predictor and its key in place.
    predictor = new_predictor
    model_name_global = model_key

    logging.info(f"Model '{model_config['name']}' loaded successfully.")
    return jsonify({
        'status': f"Model '{model_config['name']}' loaded successfully.",
        'model_info': model_config,
    })
|
|
|
|
|
@app.route('/api/model-status', methods=['GET'])
def model_status():
    """Report whether a predictor is currently loaded.

    Returns JSON ``{'status': 'not_loaded'}`` when no predictor is held;
    otherwise ``{'status': 'loaded', 'model_key': ..., 'model_info': ...}``,
    where ``model_info`` is the AVAILABLE_MODELS entry for that key (None if
    the key is not in the registry).
    """
    if not predictor:
        return jsonify({'status': 'not_loaded'})

    loaded_key = model_name_global
    return jsonify({
        'status': 'loaded',
        'model_key': loaded_key,
        'model_info': AVAILABLE_MODELS.get(loaded_key),
    })
|
|
|
|
|
@app.route('/api/available-models', methods=['GET'])
def get_available_models():
    """Return the AVAILABLE_MODELS registry as JSON, keyed by model_key."""
    catalog = AVAILABLE_MODELS
    return jsonify(catalog)
|
|
|
|
|
# Leading fields of a Binance-style k-line row that the model consumes. Rows may
# carry the full 12-field layout (close_time, quote_asset_volume, ...); any
# fields after these six are ignored.
_KLINE_INPUT_COLUMNS = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
_DEFAULT_PRED_LEN = 120


def _parse_pred_len(value):
    """Return ``value`` as a positive int, or None if it is not one.

    Integral floats such as ``120.0`` are accepted; booleans (an ``int``
    subclass in Python) are rejected.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    if isinstance(value, int) and value > 0:
        return value
    return None


def _klines_to_frame(k_lines):
    """Build the model-input DataFrame from raw k-line rows.

    Args:
        k_lines: non-empty list of rows (lists/tuples) whose first six fields are
            open time in epoch milliseconds, open, high, low, close and volume.
            Values may be JSON numbers or numeric strings.

    Returns:
        DataFrame with columns timestamp (datetime64) and numeric open, high,
        low, close, volume, in the row order received.

    Raises:
        ValueError / TypeError: if ``k_lines`` is empty or malformed, or a field
            cannot be converted to a number.
    """
    if not isinstance(k_lines, list) or not k_lines:
        raise ValueError('"k_lines" must be a non-empty list of rows.')
    n_fields = len(_KLINE_INPUT_COLUMNS)
    for i, row in enumerate(k_lines):
        if not isinstance(row, (list, tuple)) or len(row) < n_fields:
            raise ValueError(f'k_lines[{i}] must be a list with at least {n_fields} fields.')

    df = pd.DataFrame([row[:n_fields] for row in k_lines], columns=_KLINE_INPUT_COLUMNS)
    # Cast to numbers before to_datetime: pandas deprecates parsing strings
    # together with unit=, and epoch-ms values often arrive as strings.
    df['timestamp'] = pd.to_datetime(pd.to_numeric(df['timestamp']), unit='ms')
    for col in _KLINE_INPUT_COLUMNS[1:]:
        df[col] = pd.to_numeric(df[col])
    return df


def _future_timestamps(x_timestamp, pred_len):
    """Return a Series of ``pred_len`` timestamps continuing ``x_timestamp``.

    The step is the gap between the last two input timestamps, or one minute
    when only a single bar was supplied.

    Raises:
        ValueError: if that gap is not positive (e.g. unsorted or duplicated bars).
    """
    if len(x_timestamp) > 1:
        interval = x_timestamp.iloc[-1] - x_timestamp.iloc[-2]
    else:
        # Nothing to infer the spacing from; assume 1-minute bars.
        interval = pd.Timedelta(minutes=1)
    # A zero/negative (or NaT) step cannot extend the series forward in time.
    if not interval > pd.Timedelta(0):
        raise ValueError('the last two k-line timestamps must be strictly increasing.')
    future = pd.date_range(
        start=x_timestamp.iloc[-1] + interval,
        periods=pred_len,
        freq=interval,
    )
    return pd.Series(future, name='timestamp')


@app.route('/api/predict_from_data', methods=['POST'])
@require_api_key
def predict_from_data():
    """Forecast future bars from raw K-line data supplied in the request body.

    Request JSON object:
        k_lines: list of k-line rows; the first six fields of each row are
            [open time (epoch ms), open, high, low, close, volume].
        prediction_params (optional object):
            pred_len: positive integer number of bars to forecast (default 120).

    Returns:
        200 ``{'success': True, 'prediction_params': {'pred_len': ...},
        'prediction_results': [...]}`` with one record per predicted row;
        400 if no model is loaded or the request is malformed;
        500 if the predictor itself fails.
    """
    if predictor is None:
        return jsonify({'error': 'Model not loaded. Please call /api/load-model first.'}), 400

    # silent=True: a malformed/non-JSON body yields None -> JSON 400, not a raise.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'k_lines' not in data:
        return jsonify({'error': 'Missing "k_lines" in request body.'}), 400

    params = data.get('prediction_params') or {}
    if not isinstance(params, dict):
        return jsonify({'error': '"prediction_params" must be a JSON object.'}), 400
    pred_len = _parse_pred_len(params.get('pred_len', _DEFAULT_PRED_LEN))
    if pred_len is None:
        return jsonify({'error': '"pred_len" must be a positive integer.'}), 400

    # Client-side data problems are reported as 400, separate from model failures.
    try:
        df_model_input = _klines_to_frame(data['k_lines'])
        x_timestamp = df_model_input['timestamp']
        y_timestamp = _future_timestamps(x_timestamp, pred_len)
    except (ValueError, TypeError) as e:
        logging.warning(f"Rejected k_lines: {e}")
        return jsonify({'error': f'Invalid k_lines: {e}'}), 400

    logging.info(f"Making prediction with pred_len={pred_len} on data with shape {df_model_input.shape}")

    try:
        pred_df = predictor.predict(
            df=df_model_input,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            verbose=False,
        )
        # NOTE(review): orient='records' drops pred_df's index; if the predictor
        # stores the forecast timestamps in the index they are not returned — confirm.
        prediction_results = pred_df.to_dict(orient='records')
        return jsonify({
            'success': True,
            'prediction_params': {'pred_len': pred_len},
            'prediction_results': prediction_results,
        })
    except Exception as e:
        logging.exception(f"Prediction failed: {e}")
        return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500
|
|
|