Spaces:
Sleeping
Sleeping
app changes
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
|
@@ -11,6 +11,8 @@ import json
|
|
| 11 |
import math
|
| 12 |
from typing import Union
|
| 13 |
from deployment.config import load_model_config, get_input_size
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# --- Helper function to get model device ---
|
| 16 |
def get_model_device(model):
|
|
@@ -669,9 +671,6 @@ class xLSTMPredictor(nn.Module):
|
|
| 669 |
|
| 670 |
return predictions, states
|
| 671 |
|
| 672 |
-
# --- FastAPI App ---
|
| 673 |
-
app = FastAPI()
|
| 674 |
-
|
| 675 |
# --- Load Models ---
|
| 676 |
MODELS_DIR = "deployment/models"
|
| 677 |
models = {}
|
|
@@ -737,45 +736,122 @@ with open(os.path.join(MODELS_DIR, "RandomForest_model.pkl"), "rb") as f:
|
|
| 737 |
models["random_forest"] = rf_model
|
| 738 |
|
| 739 |
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
features: list
|
| 744 |
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
@app.post("/predict", response_model=InferenceResponse)
|
| 749 |
-
async def predict(request: InferenceRequest):
|
| 750 |
-
model = models.get(request.model_name)
|
| 751 |
if not model:
|
| 752 |
-
return
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
with torch.no_grad():
|
| 765 |
-
|
| 766 |
-
|
| 767 |
else: # scikit-learn models
|
| 768 |
-
#
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
#
|
| 772 |
-
#
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
|
| 777 |
-
|
|
|
|
|
|
|
|
|
|
| 778 |
|
| 779 |
@app.get("/")
|
| 780 |
def read_root():
|
| 781 |
-
return
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
|
|
|
| 11 |
import math
|
| 12 |
from typing import Union
|
| 13 |
from deployment.config import load_model_config, get_input_size
|
| 14 |
+
# BUG FIX(review): `Mount` is not exported from the top-level `fastapi`
# package (it lives in `starlette.routing`) and is never used in this
# module — importing it raises ImportError at app startup.
from fastapi import FastAPI
from gradio.themes.base import Base
|
| 16 |
|
| 17 |
# --- Helper function to get model device ---
|
| 18 |
def get_model_device(model):
|
|
|
|
| 671 |
|
| 672 |
return predictions, states
|
| 673 |
|
|
|
|
|
|
|
|
|
|
| 674 |
# --- Load Models ---
|
| 675 |
MODELS_DIR = "deployment/models"
|
| 676 |
models = {}
|
|
|
|
| 736 |
models["random_forest"] = rf_model
|
| 737 |
|
| 738 |
|
| 739 |
+
from sklearn.preprocessing import StandardScaler
|
| 740 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 741 |
+
import matplotlib.pyplot as plt
|
|
|
|
| 742 |
|
| 743 |
+
def predict(model_name, file):
    """Run inference with the selected model on an uploaded CSV file.

    Parameters
    ----------
    model_name : str
        Key into the module-level ``models`` dict ("hawk", "mamba",
        "xlstm", or a scikit-learn model name such as "random_forest").
    file : gradio file object (has a ``.name`` path attribute) or None
        The uploaded CSV containing the feature and target columns.

    Returns
    -------
    tuple[str, str | None, matplotlib.figure.Figure | None]
        (last-prediction text, metrics JSON string, actual-vs-predicted
        figure) on success, or (error message, None, None) on failure.
    """
    model = models.get(model_name)
    if not model:
        return "Model not found", None, None
    # Guard: Gradio passes None when no file was uploaded.
    if file is None:
        return "No file uploaded", None, None

    df = pd.read_csv(file.name)

    config = load_model_config(model_name, models_dir="deployment/models")
    feature_cols = config["feature_cols"]
    target_col = config["target_col"]
    seq_length = config["seq_length"]

    # Guard: a clear message beats a KeyError if the CSV schema is wrong.
    missing = [c for c in list(feature_cols) + [target_col] if c not in df.columns]
    if missing:
        return f"CSV is missing required columns: {missing}", None, None

    # Data preparation (assuming the uploaded file is the test set).
    # NOTE(review): fitting the scaler on random data is a placeholder —
    # a real deployment must load the scaler fitted at training time,
    # otherwise the transformed features are statistically meaningless.
    scaler = StandardScaler()
    scaler.fit(np.random.rand(100, len(feature_cols)))
    features = scaler.transform(df[feature_cols].values)
    targets = df[target_col].values

    # Guard: need at least seq_length + 1 rows to form one window;
    # otherwise the loop below produces empty arrays and indexing crashes.
    if len(features) <= seq_length:
        return f"Need more than {seq_length} rows to form a sequence", None, None

    # Build overlapping sliding windows of length seq_length.
    n_windows = len(features) - seq_length
    X_test = np.array([features[i : i + seq_length] for i in range(n_windows)])
    y_test = np.array([targets[i : i + seq_length] for i in range(n_windows)])
    X_test = torch.FloatTensor(X_test)

    torch_model_names = ("hawk", "mamba", "xlstm")

    # Prediction
    if model_name in torch_model_names:
        # `device` is a module-level global — assumed to be set where the
        # torch models are loaded; TODO confirm.
        X_test = X_test.to(device)
        with torch.no_grad():
            predictions, _ = model(X_test)
        predictions = predictions.cpu().numpy()
    else:  # scikit-learn models
        # Flatten each (seq_length, n_features) window to one row, and
        # convert the tensor to numpy — sklearn expects array input.
        X_test_reshaped = X_test.cpu().numpy().reshape(len(X_test), -1)
        predictions = model.predict(X_test_reshaped)
        # sklearn returns one value per window; tile it across the
        # sequence axis so it matches the y_test (n_windows, seq_length)
        # shape used below.
        predictions = np.repeat(predictions[:, np.newaxis], y_test.shape[1], axis=1)

    # For PyTorch models, predictions carry a trailing feature dimension.
    if model_name in torch_model_names:
        y_pred_for_metrics = predictions[:, -1, 0]
    else:
        y_pred_for_metrics = predictions[:, -1]

    # Calculate metrics on the last step of each window.
    y_true_for_metrics = y_test[:, -1]
    mse = mean_squared_error(y_true_for_metrics, y_pred_for_metrics)
    metrics = {
        "MSE": float(mse),
        "RMSE": float(np.sqrt(mse)),  # reuse mse instead of recomputing it
        "MAE": float(mean_absolute_error(y_true_for_metrics, y_pred_for_metrics)),
        "R2": float(r2_score(y_true_for_metrics, y_pred_for_metrics)),
    }
    metrics_str = json.dumps(metrics, indent=4)

    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(y_true_for_metrics, label="Actual")
    ax.plot(y_pred_for_metrics, label="Predicted")
    ax.set_title("Predictions vs Actual")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Value")
    ax.legend()
    ax.grid(True)

    # Headline number: the last prediction of the last window.
    if model_name in torch_model_names:
        last_prediction = predictions[-1, -1, 0]
    else:
        last_prediction = predictions[-1, -1]

    return f"{last_prediction:.4f}", metrics_str, fig
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
# --- Gradio Interface ---
# Declarative UI: left column holds the inputs (model choice + CSV
# upload), right column the outputs (prediction, metrics JSON, plot).
# `demo` is mounted onto the FastAPI app further down in this module.
with gr.Blocks(theme=Base(), title="Stock Predictor") as demo:
    gr.Markdown(
        """
        # Stock Price Predictor
        Select a model and upload a CSV file with the required features to get a prediction.
        """
    )
    with gr.Row():
        with gr.Column():
            # Choices come from the module-level `models` dict populated
            # at import time, so only loaded models are selectable.
            model_name = gr.Dropdown(
                label="Select Model", choices=list(models.keys())
            )
            feature_input = gr.File(
                label="Upload CSV with features",
            )
            predict_btn = gr.Button("Predict")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
            metrics_output = gr.Textbox(label="Metrics")
            plot_output = gr.Plot(label="Plots")

    # Wire the button to the module-level `predict` function; the output
    # order must match predict's (prediction, metrics, figure) return.
    predict_btn.click(
        fn=predict,
        inputs=[model_name, feature_input],
        outputs=[prediction_output, metrics_output, plot_output],
    )
|
| 847 |
|
| 848 |
+
# --- FastAPI App ---
app = FastAPI()

# NOTE(review): mid-file import — conventionally this belongs at the top
# of the module with the other fastapi imports.
from fastapi.responses import RedirectResponse

@app.get("/")
def read_root():
    # Send visitors at the bare root straight to the Gradio UI mounted
    # at /gradio below.
    return RedirectResponse(url="/gradio")

# Mount the Gradio Blocks app (`demo`, defined above) under /gradio.
# mount_gradio_app returns the same FastAPI app, so the route registered
# above is preserved.
app = gr.mount_gradio_app(app, demo, path="/gradio")
|