# kronos / app.py
# Last commit 8334f8b by yangyang158:
#   feat: Integrate UI control panel and fix CORS issue
import hmac
import logging
import os
import sys
from functools import wraps

import pandas as pd
import torch
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from huggingface_hub import hf_hub_download
# Configure root logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# The Kronos classes come from the project-local 'model' package. Nothing in
# this module modifies sys.path, so that package must already be importable.
# The service cannot work without it, so the failure is logged and the
# process exits with status 1.
try:
    from model.kronos import Kronos, KronosTokenizer, KronosPredictor
except ImportError as import_error:
    logging.error(f"Could not import from model.kronos: {import_error}")
    sys.exit(1)

# --- Globals ---
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Predictor currently held in memory; stays None until a model is loaded.
predictor = None
# Key into AVAILABLE_MODELS naming the selected (default) model.
model_name_global = "kronos-base"
# Shared secret read from the environment; None when the variable is unset.
API_KEY = os.environ.get("KRONOS_API_KEY")
# Catalogue of models this service can load, keyed by the public "model_key".
# Each entry carries the Hugging Face repo ids of the model and its tokenizer
# plus descriptive metadata that is returned to API clients as-is.
AVAILABLE_MODELS = {
    "kronos-mini": {
        "name": "Kronos-mini",
        "model_id": "NeoQuasar/Kronos-mini",
        "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-2k",
        "context_length": 2048,
        "params": "4.1M",
        "description": "Lightweight model, suitable for fast prediction",
    },
    "kronos-small": {
        "name": "Kronos-small",
        "model_id": "NeoQuasar/Kronos-small",
        "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base",
        "context_length": 512,
        "params": "24.7M",
        "description": "Small model, balanced performance and speed",
    },
    "kronos-base": {
        "name": "Kronos-base",
        "model_id": "NeoQuasar/Kronos-base",
        "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base",
        "context_length": 512,
        "params": "102.3M",
        "description": "Base model, provides better prediction quality",
    },
}
# --- Helper Functions ---
def download_model_from_hf(model_name, local_dir="."):
    """Fetch a model's config and weights from the Hugging Face Hub.

    Downloads ``config.json`` and then ``pytorch_model.bin`` from the
    repository ``model_name`` into ``local_dir``.

    Returns:
        True when both downloads complete; False if any exception was raised
        along the way (the exception is logged, not propagated).
    """
    logging.info(f"Downloading model '{model_name}' from Hugging Face Hub...")
    try:
        for wanted_file in ("config.json", "pytorch_model.bin"):
            hf_hub_download(repo_id=model_name, filename=wanted_file, local_dir=local_dir)
        logging.info("Model downloaded successfully.")
    except Exception as download_error:
        logging.error(f"Failed to download model: {download_error}")
        return False
    else:
        return True
# --- API Authentication ---
def require_api_key(f):
    """Decorator that protects a route with a bearer-token API key.

    Behaviour of the wrapped view:

    * When the module-level ``API_KEY`` is empty or unset, authentication is
      skipped (a warning is logged) and the view runs unconditionally.
    * Otherwise the request must carry ``Authorization: Bearer <token>``.
      A missing or differently-shaped header yields a 401 JSON error.
    * The token is everything after ``"Bearer "`` (surrounding whitespace
      stripped) and must equal ``API_KEY`` exactly, otherwise a 401 JSON
      error is returned.

    Args:
        f: The view function to protect.

    Returns:
        The wrapped view; ``functools.wraps`` keeps the original metadata so
        Flask still registers the endpoint under the view's own name.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # If KRONOS_API_KEY is not set on the server, skip authentication (for local/dev)
        if not API_KEY:
            logging.warning("API key not set. Skipping authentication.")
            return f(*args, **kwargs)

        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Authorization header is missing or invalid. Use Bearer token.'}), 401

        # Take the whole remainder as the token. Splitting on spaces and
        # keeping only the second field would accept "Bearer <key> junk" and
        # could never match a key that itself contains a space.
        token = auth_header[len('Bearer '):].strip()

        # Constant-time comparison so response timing does not reveal how many
        # leading characters of the key were guessed correctly. Compare bytes:
        # compare_digest raises TypeError for non-ASCII str arguments.
        if not hmac.compare_digest(token.encode('utf-8'), API_KEY.encode('utf-8')):
            return jsonify({'error': 'Invalid API Key.'}), 401
        return f(*args, **kwargs)
    return decorated_function
# --- API Endpoints ---
@app.route('/')
def index():
    """Return ``index.html`` (the visualizer page) from the '.' directory."""
    page_name = 'index.html'
    return send_from_directory('.', page_name)
@app.route('/api/load-model', methods=['POST'])
@require_api_key
def load_model_endpoint():
    """Load (or switch) the prediction model held in the module globals.

    Request body (JSON object, optional):
        model_key: A key of ``AVAILABLE_MODELS``. Defaults to the currently
            selected key (``model_name_global``).
        force_reload: Truthy to reload even when the requested model is
            already in memory. Defaults to False.

    Responses:
        200: The model was loaded, or was already loaded and no reload was
            forced.
        400: ``model_key`` is not a string key of ``AVAILABLE_MODELS``; the
            allowed keys are listed in the response.
        500: Loading raised an exception; its message is returned.

    On success the globals ``predictor`` and ``model_name_global`` are
    replaced; on failure they are left untouched.
    """
    global predictor, model_name_global

    # A missing, empty, non-JSON or non-object body must not crash the
    # handler with an AttributeError on ``.get``; treat it as "use defaults".
    # ``force=True`` mirrors how /api/predict parses its body.
    json_data = request.get_json(force=True, silent=True)
    if not isinstance(json_data, dict):
        json_data = {}
    model_key = json_data.get('model_key', model_name_global)
    force_reload = json_data.get('force_reload', False)

    # Validate if the requested model is in the allowed list. The isinstance
    # guard matters because a JSON list/object is unhashable and would raise
    # TypeError on the membership test.
    if not isinstance(model_key, str) or model_key not in AVAILABLE_MODELS:
        return jsonify({
            'error': "Invalid model_key. Please choose from the allowed models.",
            'allowed_models': list(AVAILABLE_MODELS.keys())
        }), 400

    if predictor and not force_reload and model_name_global == model_key:
        return jsonify({'status': 'Model already loaded.'})

    try:
        model_config = AVAILABLE_MODELS[model_key]
        model_id = model_config['model_id']
        tokenizer_id = model_config['tokenizer_id']
        logging.info(f"Attempting to load model: {model_id}")
        logging.info(f"Attempting to load tokenizer: {tokenizer_id}")

        # Determine device
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {device}")

        # --- Proxy Setup ---
        # Pick up proxy settings from the environment, dropping unset/empty ones.
        proxy_env = (
            ("http", os.environ.get("HTTP_PROXY")),
            ("https", os.environ.get("HTTPS_PROXY")),
        )
        proxies = {scheme: url for scheme, url in proxy_env if url}
        if proxies:
            # Log only which schemes are proxied: proxy URLs can embed
            # credentials (http://user:password@host) that must not reach logs.
            logging.info(f"Using proxies for schemes: {sorted(proxies)}")

        # Load model and tokenizer with proxy support (None when no proxy is set)
        model = Kronos.from_pretrained(model_id, proxies=proxies or None)
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_id, proxies=proxies or None)

        # Create the predictor wrapper; globals are only replaced once
        # everything above has succeeded.
        predictor = KronosPredictor(model, tokenizer, device=device)
        model_name_global = model_key
        logging.info(f"Model '{model_config['name']}' loaded successfully.")
        return jsonify({
            'status': f"Model '{model_config['name']}' loaded successfully.",
            'model_info': model_config
        })
    except Exception as e:
        # Route-level boundary: log with traceback and report the failure.
        logging.exception(f"Error loading model: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/model-status', methods=['GET'])
def model_status():
    """Report whether a predictor is currently held in memory.

    Returns a JSON object with ``status`` set to ``'not_loaded'``, or to
    ``'loaded'`` together with the active model key and its catalogue entry.
    """
    if not predictor:
        return jsonify({'status': 'not_loaded'})

    active_key = model_name_global
    return jsonify({
        'status': 'loaded',
        'model_key': active_key,
        'model_info': AVAILABLE_MODELS.get(active_key),
    })
@app.route('/api/available-models', methods=['GET'])
def get_available_models():
    """Expose the ``AVAILABLE_MODELS`` catalogue as a JSON response."""
    catalogue = AVAILABLE_MODELS
    return jsonify(catalogue)
# Columns the model consumes, in the order they appear at the start of each
# K-line row. Anything after these in a row is ignored.
_KLINE_MODEL_COLUMNS = ['timestamp', 'open', 'high', 'low', 'close', 'volume']


def _klines_to_frame(k_lines):
    """Turn raw K-line rows into the typed DataFrame passed to the predictor.

    Raises:
        ValueError: If ``k_lines`` is not a non-empty list, a row is not a
            list with at least six values, or a value cannot be converted.
    """
    if not isinstance(k_lines, list) or not k_lines:
        raise ValueError('"k_lines" must be a non-empty list of rows.')

    width = len(_KLINE_MODEL_COLUMNS)
    rows = []
    for position, line in enumerate(k_lines):
        if not isinstance(line, (list, tuple)) or len(line) < width:
            raise ValueError(
                f'k_lines[{position}] must be a list with at least {width} values '
                f'({", ".join(_KLINE_MODEL_COLUMNS)}).'
            )
        # Only the leading columns are used, so rows may carry any number of
        # extra trailing fields (or none at all).
        rows.append(list(line[:width]))

    df = pd.DataFrame(rows, columns=_KLINE_MODEL_COLUMNS)
    # --- Data Type Conversion ---
    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
    for col in _KLINE_MODEL_COLUMNS[1:]:
        df[col] = pd.to_numeric(df[col])
    return df


def _future_timestamps(x_timestamp, pred_len):
    """Generate ``pred_len`` timestamps continuing the historical series.

    The step is the gap between the last two historical points (assuming a
    consistent K-line interval); with a single point it falls back to one
    minute.

    Raises:
        ValueError: If the last two timestamps are not strictly increasing.
    """
    if len(x_timestamp) > 1:
        interval = x_timestamp.iloc[-1] - x_timestamp.iloc[-2]
    else:
        # If only one data point, assume a 1-minute interval as a fallback
        interval = pd.Timedelta(minutes=1)

    # A zero step makes date_range fail and a negative one would "predict"
    # into the past; both mean the input is not in ascending time order.
    if pd.isna(interval) or interval <= pd.Timedelta(0):
        raise ValueError('The last two "k_lines" timestamps must be strictly increasing.')

    future = pd.date_range(
        start=x_timestamp.iloc[-1] + interval,
        periods=pred_len,
        freq=interval
    )
    # Convert DatetimeIndex to Series to prevent '.dt' accessor error inside the model
    return pd.Series(future, name='timestamp')


def _format_prediction(pred_df):
    """Convert the predictor's DataFrame into ``[ts_ms, o, h, l, c, v]`` rows.

    The timestamp of every row is an integer number of Unix milliseconds.
    """
    pred_df_reset = pred_df.reset_index()
    # Integer floor-division of the offset from the epoch: exact (no float
    # round-trip) and independent of the datetime column's storage resolution.
    epoch = pd.Timestamp('1970-01-01')
    elapsed = pred_df_reset['timestamp'] - epoch
    timestamps_ms = (elapsed // pd.Timedelta(milliseconds=1)).astype('int64').tolist()

    value_rows = pred_df_reset[_KLINE_MODEL_COLUMNS[1:]].values.tolist()
    # Assemble rows by hand: calling .values on the mixed int/float frame
    # would upcast the integer timestamps to floats.
    return [[ts, *values] for ts, values in zip(timestamps_ms, value_rows)]


@app.route('/api/predict', methods=['POST'])
@require_api_key
def predict():
    """
    Receives raw K-line data in the request body, makes a prediction,
    and returns the results.

    Request body (JSON object):
        k_lines: Non-empty list of rows. The first six values of each row
            are timestamp (Unix milliseconds), open, high, low, close and
            volume; any further values in a row are ignored.
        prediction_params: Optional object. ``pred_len`` (positive integer,
            default 120) is the number of future points to predict.

    Responses:
        200: ``success``, the ``prediction_params`` used and
            ``prediction_results`` as rows of
            ``[timestamp_ms, open, high, low, close, volume]``.
        400: No model is loaded, or the request body is invalid.
        500: The predictor (or result formatting) raised an exception.
    """
    if not predictor:
        return jsonify({'error': 'Model not loaded. Please call /api/load-model first.'}), 400

    # silent=True: an unparsable body yields our JSON error below instead of
    # Flask's generic error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict) or 'k_lines' not in data:
        return jsonify({'error': 'Missing "k_lines" in request body.'}), 400

    params = data.get('prediction_params') or {}
    if not isinstance(params, dict):
        return jsonify({'error': '"prediction_params" must be an object.'}), 400

    pred_len = params.get('pred_len', 120)
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(pred_len, bool) or not isinstance(pred_len, int) or pred_len <= 0:
        return jsonify({'error': '"pred_len" must be a positive integer.'}), 400

    # Input problems are the caller's fault: report them as 400, not 500.
    try:
        df_model_input = _klines_to_frame(data['k_lines'])
        # --- Timestamp Generation for Predictor ---
        # The predictor requires historical and future timestamps
        x_timestamp = df_model_input['timestamp']
        y_timestamp = _future_timestamps(x_timestamp, pred_len)
    except (ValueError, TypeError) as e:
        logging.error(f"Invalid prediction input: {e}")
        return jsonify({'error': f'Invalid input data: {e}'}), 400

    try:
        logging.info(f"Making prediction with pred_len={pred_len} on data with shape {df_model_input.shape}")
        # Make prediction using the predictor wrapper
        pred_df = predictor.predict(
            df=df_model_input,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            verbose=False  # Keep logs clean
        )
        # --- Format results to match input format ---
        prediction_results = _format_prediction(pred_df)
        return jsonify({
            'success': True,
            'prediction_params': {'pred_len': pred_len},
            'prediction_results': prediction_results
        })
    except Exception as e:
        # Route-level boundary: log with traceback and report the failure.
        logging.exception(f"Prediction failed: {e}")
        return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500
if __name__ == '__main__':
    # This block is for local debugging purposes.
    # The production server will use a WSGI server like Gunicorn.
    #
    # The server binds to every interface, and Flask's debug mode enables the
    # interactive Werkzeug debugger, which lets anyone who can reach the port
    # execute arbitrary code. Debug mode is therefore opt-in: set
    # FLASK_DEBUG=1 (or true/yes/on) in the environment to enable it.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)