Ade1ola commited on
Commit
dacd9af
·
verified ·
1 Parent(s): 643dbf2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ import joblib
6
+ from tensorflow.keras.models import load_model
7
+
8
+ # Load the trained LSTM model
9
+ model = load_model("lstm_gru_model5.h5")
10
+
11
+ # Load the MinMaxScaler
12
+ scaler = joblib.load("scaler.pkl")
13
+
14
+ def preprocess_input(data):
15
+ """Preprocess input data for LSTM model."""
16
+ scaled_data = scaler.transform(np.array(data).reshape(-1, 1))
17
+ return np.array(scaled_data).reshape(1, len(data), 1) # Reshape for LSTM
18
+
19
+ def predict_forex(prices):
20
+ """Predict the next forex price based on the input sequence."""
21
+ try:
22
+ input_data = [float(price) for price in prices.split(",")]
23
+ if len(input_data) < 10: # Ensure enough input data
24
+ return "Please provide at least 10 previous forex prices."
25
+
26
+ preprocessed_data = preprocess_input(input_data[-10:]) # Use last 10 prices
27
+ prediction = model.predict(preprocessed_data)
28
+ predicted_price = scaler.inverse_transform(prediction)[0][0] # Convert back to original scale
29
+ return f"Predicted Next Price: {predicted_price:.5f}"
30
+ except Exception as e:
31
+ return f"Error: {str(e)}"
32
+
33
+ def batch_predict(file):
34
+ """Batch prediction for CSV files."""
35
+ try:
36
+ df = pd.read_csv(file)
37
+ if "prices" not in df.columns:
38
+ return "CSV must have a 'prices' column with historical data."
39
+
40
+ df["predictions"] = df["prices"].rolling(window=10).apply(lambda x: predict_forex(",".join(map(str, x))) if len(x) == 10 else None)
41
+ return df.dropna()
42
+ except Exception as e:
43
+ return f"Error: {str(e)}"
44
+
45
+ # Gradio UI
46
+ demo = gr.Interface(
47
+ fn=predict_forex,
48
+ inputs=gr.Textbox(label="Enter last 10 forex prices (comma-separated)"),
49
+ outputs=gr.Textbox(label="Predicted Next Price"),
50
+ title="Forex Price Predictor",
51
+ description="Enter the last 10 forex prices to predict the next price. Upload CSV for batch predictions.",
52
+ examples=[
53
+ ["1.2345,1.2350,1.2360,1.2370,1.2380,1.2390,1.2400,1.2410,1.2420,1.2430"]
54
+ ],
55
+ allow_flagging="never"
56
+ )
57
+
58
+ batch_demo = gr.Interface(
59
+ fn=batch_predict,
60
+ inputs=gr.File(label="Upload CSV"),
61
+ outputs=gr.Dataframe(label="Predictions"),
62
+ title="Batch Prediction",
63
+ description="Upload a CSV with a 'prices' column for batch predictions."
64
+ )
65
+
66
+ gr.TabbedInterface([demo, batch_demo], ["Single Prediction", "Batch Prediction"]).launch()