Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| from tensorflow.keras.models import load_model | |
| import pickle | |
# Keras checkpoint for each selectable architecture. The keys are the labels
# shown in the UI dropdown and are also the names used to look up ensemble
# weights when combining predictions.
model_paths = {
    "Simple RNN": "model_SimpleRNN.keras",
    "GRU": "model_GRU.keras",
    "LSTM": "model_LSTM.keras",
}

# Deserialize every checkpoint once, at import time, in the order listed above,
# so that each request reuses the already-loaded models.
models = dict(zip(model_paths.keys(), map(load_model, model_paths.values())))

# Ensemble weights; the prediction code reads them with `.get(label, 0)`, so
# this is presumably a dict keyed by the labels above — confirm against the
# script that produced the file.
# NOTE(review): pickle.load can execute arbitrary code. This is only acceptable
# if ensemble_weights.pkl is a trusted artifact shipped with the app; never
# point it at user-supplied data.
with open("ensemble_weights.pkl", mode="rb") as f:
    ensemble_weights = pickle.load(f)
| # Preprocessing function | |
| def preprocess_input(user_input_text): | |
| try: | |
| values = list(map(float, user_input_text.replace('\n', '').split(','))) | |
| if len(values) != 35: | |
| return None, "❌ Input must be exactly 35 floats (5 time steps × 7 features)." | |
| input_array = np.array(values).reshape(1, 5, 7) | |
| return input_array, None | |
| except ValueError: | |
| return None, "❌ Invalid input: Ensure all values are valid floats." | |
| # Prediction function | |
| def predict(input_text, selected_model): | |
| input_array, error = preprocess_input(input_text) | |
| if error: | |
| return error | |
| if selected_model == "Ensemble": | |
| preds = [] | |
| for name, model in models.items(): | |
| pred = model.predict(input_array)[0][0] | |
| weight = ensemble_weights.get(name, 0) | |
| preds.append(weight * pred) | |
| prediction = sum(preds) | |
| else: | |
| model = models[selected_model] | |
| prediction = model.predict(input_array)[0][0] | |
| return f"✅ Predicted Probability of Positive Class: {prediction:.4f}" | |
| # Gradio interface | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Live RNN Model Ensemble Tester") | |
| gr.Markdown("Input 35 comma-separated float values (5 time steps × 7 features)") | |
| input_text = gr.Textbox(label="Input Features (comma-separated, 35 values)", lines=5) | |
| selected_model = gr.Dropdown(choices=list(model_paths.keys()) + ["Ensemble"], value="Simple RNN", label="Select Model") | |
| output = gr.Textbox(label="Output") | |
| submit_btn = gr.Button("Submit") | |
| submit_btn.click(fn=predict, inputs=[input_text, selected_model], outputs=output) | |
| gr.Examples([ | |
| ["0.1,0.2,0.3,0.4,0.5,0.6,0.7,"*5, "Simple RNN"] | |
| ], inputs=[input_text, selected_model]) | |
| # Launch | |
| if __name__ == "__main__": | |
| demo.launch() | |