# Stock_Res / app.py
# Author: harshitmahour360 — last commit 5b14964 "Update app.py" (verified)
import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model
import pickle
# Map each label shown in the UI dropdown to the saved Keras model file.
model_paths = {
    "Simple RNN": "model_SimpleRNN.keras",
    "GRU": "model_GRU.keras",
    "LSTM": "model_LSTM.keras",
}
# Load every model once at import time; keys keep model_paths' order.
models = dict(zip(model_paths.keys(), map(load_model, model_paths.values())))

# Per-model ensemble weights (presumably a dict keyed by the labels above —
# verify against the pickle's producer).
# NOTE(review): pickle.load can execute arbitrary code; this is only
# acceptable because the path is a fixed local file, never user-supplied.
with open("ensemble_weights.pkl", mode="rb") as f:
    ensemble_weights = pickle.load(f)
# Preprocessing function
def preprocess_input(user_input_text):
try:
values = list(map(float, user_input_text.replace('\n', '').split(',')))
if len(values) != 35:
return None, "❌ Input must be exactly 35 floats (5 time steps × 7 features)."
input_array = np.array(values).reshape(1, 5, 7)
return input_array, None
except ValueError:
return None, "❌ Invalid input: Ensure all values are valid floats."
# Prediction function
def predict(input_text, selected_model):
input_array, error = preprocess_input(input_text)
if error:
return error
if selected_model == "Ensemble":
preds = []
for name, model in models.items():
pred = model.predict(input_array)[0][0]
weight = ensemble_weights.get(name, 0)
preds.append(weight * pred)
prediction = sum(preds)
else:
model = models[selected_model]
prediction = model.predict(input_array)[0][0]
return f"✅ Predicted Probability of Positive Class: {prediction:.4f}"
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Live RNN Model Ensemble Tester")
gr.Markdown("Input 35 comma-separated float values (5 time steps × 7 features)")
input_text = gr.Textbox(label="Input Features (comma-separated, 35 values)", lines=5)
selected_model = gr.Dropdown(choices=list(model_paths.keys()) + ["Ensemble"], value="Simple RNN", label="Select Model")
output = gr.Textbox(label="Output")
submit_btn = gr.Button("Submit")
submit_btn.click(fn=predict, inputs=[input_text, selected_model], outputs=output)
gr.Examples([
["0.1,0.2,0.3,0.4,0.5,0.6,0.7,"*5, "Simple RNN"]
], inputs=[input_text, selected_model])
# Launch
if __name__ == "__main__":
demo.launch()