# crypto-predict / app.py
# shaheerawan3 — commit 9fa8558 "Update app.py" (verified)
import os
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import pandas as pd
import yfinance as yf
import streamlit as st
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime, timedelta
import joblib
import warnings
import ta
from tqdm import tqdm
# Silence every warning category for the whole process.
# NOTE(review): this also hides warnings that flag real problems (e.g. shape
# broadcasting in loss functions); consider narrowing by category or module.
warnings.simplefilter('ignore')
class PriceScaler:
    """Min-max scale a flat sequence of prices and map scaled values back.

    Thin wrapper around ``MinMaxScaler`` that accepts 1-D input: values are
    reshaped into a single column for sklearn and flattened again on output.
    """

    def __init__(self):
        self.scaler = MinMaxScaler()

    @staticmethod
    def _as_column(values):
        # sklearn scalers work on 2-D (n_samples, n_features) arrays.
        return np.array(values).reshape(-1, 1)

    def fit_transform(self, data):
        """Fit the scaler on *data* and return the scaled values as a 1-D array."""
        scaled = self.scaler.fit_transform(self._as_column(data))
        return scaled.flatten()

    def inverse_transform(self, data):
        """Map scaled *data* back to the fitted price range, as a 1-D array."""
        restored = self.scaler.inverse_transform(self._as_column(data))
        return restored.flatten()
class CryptoPredictor(nn.Module):
    """Bidirectional LSTM regressor that also emits a confidence score.

    ``forward`` expects a batch-first tensor (``batch_first=True`` below) of
    shape ``(batch, seq_len, input_dim)`` and returns ``(prediction,
    confidence)``, each of shape ``(batch, 1)``. Confidence goes through a
    sigmoid, so it lies in (0, 1).
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Inter-layer dropout has nothing to act on with a single layer, so it
        # is disabled in that case.
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, num_layers=num_layers, batch_first=True,
            dropout=dropout if num_layers > 1 else 0, bidirectional=True,
        )
        # Forward and backward outputs are concatenated: hidden_dim * 2 features.
        self.bn = nn.BatchNorm1d(hidden_dim * 2)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.confidence_fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(prediction, confidence)`` for a batch of sequences *x*."""
        batch_size = x.size(0)
        # Zero initial states, two directions per layer. new_zeros allocates
        # directly on x's device with x's dtype, so there is no per-call host
        # allocation/transfer and the model still works after .double()/.half().
        state_shape = (self.num_layers * 2, batch_size, self.hidden_dim)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        lstm_out, _ = self.lstm(x, (h0, c0))
        # NOTE(review): for the backward direction, the output at the last time
        # step has only seen that one step (h_n would summarise the whole
        # sequence). Kept unchanged so existing checkpoints behave the same.
        last_hidden = lstm_out[:, -1, :]
        # BatchNorm1d in training mode needs more than one sample per batch.
        normalized_hidden = self.bn(last_hidden)
        prediction = self.fc(normalized_hidden)
        confidence = self.confidence_fc(normalized_hidden)
        return prediction, confidence
class CryptoAnalyzer:
    """Download market data, build features, and train/serve a CryptoPredictor.

    Model weights are saved per symbol under ``model_dir`` and reused by later
    ``get_predictions`` calls. ``cache_dir`` is created but not otherwise used
    in this class.
    """

    def __init__(self, model_dir="models", cache_dir="cache"):
        # Feature scaler; refit on every prepare_data() call.
        self.scaler = MinMaxScaler()
        # Separate 1-D scaler for Close, so predictions can be mapped back to prices.
        self.price_scaler = PriceScaler()
        self.model_dir = model_dir
        self.cache_dir = cache_dir
        os.makedirs(model_dir, exist_ok=True)
        os.makedirs(cache_dir, exist_ok=True)
        # Column order fixes the model's input_dim and feature layout.
        self.feature_columns = [
            'Open', 'High', 'Low', 'Close', 'Volume', 'Returns', 'Volatility',
            'MA5', 'MA20', 'RSI', 'Price_Momentum', 'Volume_Momentum', 'MACD',
            'BB_upper', 'BB_lower', 'Stoch_K', 'Stoch_D', 'ADX', 'ATR',
        ]
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _flatten_columns(df):
        """Return *df* with single-level column labels; the input is not modified.

        Newer yfinance releases return ``(Price, Ticker)`` MultiIndex columns
        even for one ticker, which makes ``df['Close']`` a DataFrame instead of
        a Series. The first level (the price field) is kept.
        """
        if isinstance(df.columns, pd.MultiIndex):
            df = df.copy()
            df.columns = df.columns.get_level_values(0)
        return df

    def get_data(self, symbol, days):
        """Return the last *days* rows of ``{symbol}-USD`` history plus indicators.

        ``days + 30`` calendar days are downloaded so the rolling-window
        indicators can warm up. Rows that still contain NaN or +/-inf (e.g. a
        rate of change computed from a zero volume) are dropped, so fewer than
        *days* rows may be returned.

        Raises:
            ValueError: if yfinance returns no data for the symbol.
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days + 30)
        df = yf.download(f"{symbol}-USD", start=start_date, end=end_date, progress=False)
        if df is None or df.empty:
            raise ValueError(f"No data available for {symbol}")
        df = self._flatten_columns(df)

        close, high, low = df['Close'], df['High'], df['Low']
        df['Returns'] = close.pct_change()
        df['Volatility'] = df['Returns'].rolling(window=20).std()
        df['MA5'] = close.rolling(window=5).mean()
        df['MA20'] = close.rolling(window=20).mean()
        df['RSI'] = ta.momentum.rsi(close)
        df['Price_Momentum'] = ta.momentum.roc(close)
        df['Volume_Momentum'] = ta.momentum.roc(df['Volume'])
        # ta.trend.macd returns the MACD line as a Series, so assign it directly
        # (indexing it with .iloc[:, 0] raises "Too many indexers").
        df['MACD'] = ta.trend.macd(close)
        bollinger = ta.volatility.BollingerBands(close)
        df['BB_upper'] = bollinger.bollinger_hband()
        df['BB_lower'] = bollinger.bollinger_lband()
        stoch = ta.momentum.StochasticOscillator(high, low, close)
        df['Stoch_K'] = stoch.stoch()
        df['Stoch_D'] = stoch.stoch_signal()
        df['ADX'] = ta.trend.adx(high, low, close)
        df['ATR'] = ta.volatility.average_true_range(high, low, close)

        # dropna() alone keeps +/-inf, which MinMaxScaler rejects later.
        df = df.replace([np.inf, -np.inf], np.nan).dropna()
        return df.iloc[-days:]

    def prepare_data(self, df, lookback):
        """Scale features and cut them into sliding windows.

        Returns ``(X, y)`` on ``self.device``. ``X`` has shape
        ``(n_windows, lookback, n_features)``. ``y`` has shape ``(n_windows,)``,
        where ``y[i]`` is the scaled Close of the row right after window ``i``.
        Both scalers are refit on *df*.

        Raises:
            ValueError: if *df* has no more than *lookback* rows (no window fits).
        """
        n_windows = len(df) - lookback
        if lookback < 1 or n_windows < 1:
            raise ValueError(
                f"Not enough data for a lookback of {lookback} days: got {len(df)} "
                "usable rows; increase Historical Days or reduce the Lookback Period."
            )
        scaled_features = self.scaler.fit_transform(df[self.feature_columns].values)
        scaled_close = self.price_scaler.fit_transform(df['Close'].values)

        windows = np.stack([scaled_features[i:i + lookback] for i in range(n_windows)])
        targets = scaled_close[lookback:]
        X = torch.as_tensor(windows, dtype=torch.float32, device=self.device)
        y = torch.as_tensor(targets, dtype=torch.float32, device=self.device).reshape(-1)
        return X, y

    def get_model_path(self, symbol):
        """Path of the saved state_dict for *symbol* (lower-cased)."""
        return os.path.join(self.model_dir, f"{symbol.lower()}_model.pth")

    def get_scaler_path(self, symbol):
        """Path of the pickled feature scaler for *symbol* (lower-cased)."""
        return os.path.join(self.model_dir, f"{symbol.lower()}_scaler.pkl")

    @staticmethod
    def _choose_batch_size(n_samples, max_batch=32):
        """Return ``(batch_size, drop_last)`` for *n_samples* training windows.

        The model's BatchNorm1d rejects one-sample batches in training mode, so
        the batch size is at least 2, and a trailing batch of exactly one sample
        is dropped.
        """
        batch_size = max(2, min(max_batch, n_samples // 4))
        drop_last = n_samples % batch_size == 1
        return batch_size, drop_last

    def train_model(self, X, y, symbol):
        """Train a new CryptoPredictor on ``(X, y)`` and return the best epoch's model.

        Runs up to 50 epochs, with early stopping after 10 epochs without
        improvement. Each improving epoch's weights are written to
        ``get_model_path(symbol)``.

        Raises:
            ValueError: if fewer than two training windows are available.
        """
        if len(X) < 2:
            raise ValueError("Need at least 2 training windows to train a model")
        model = CryptoPredictor(X.shape[2]).to(self.device)
        criterion = nn.HuberLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        batch_size, drop_last = self._choose_batch_size(len(X))
        dataset = torch.utils.data.TensorDataset(X, y)
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last
        )
        model_path = self.get_model_path(symbol)
        best_loss = float('inf')
        best_state = None
        patience = 10
        patience_counter = 0
        model.train()
        with tqdm(range(50), desc=f"Training {symbol} model") as pbar:
            for _ in pbar:
                total_loss = 0.0
                for batch_X, batch_y in train_loader:
                    optimizer.zero_grad()
                    predictions, _ = model(batch_X)
                    # predictions is (B, 1) and batch_y is (B,). Squeeze so the loss
                    # is element-wise instead of silently broadcasting to (B, B).
                    loss = criterion(predictions.squeeze(-1), batch_y)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    total_loss += loss.item()
                avg_loss = total_loss / len(train_loader)
                scheduler.step(avg_loss)
                pbar.set_postfix({'loss': f'{avg_loss:.6f}'})
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    patience_counter = 0
                    best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                    torch.save(model.state_dict(), model_path)
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        break
        # Return the best checkpoint, which is also what later runs load from disk.
        if best_state is not None:
            model.load_state_dict(best_state)
        return model

    @staticmethod
    def _regression_metrics(actual, predicted):
        """Return ``(rmse, mape_percent, r2)`` for two equal-length price arrays."""
        actual = np.asarray(actual, dtype=float)
        predicted = np.asarray(predicted, dtype=float)
        errors = actual - predicted
        rmse = float(np.sqrt(np.mean(errors ** 2)))
        mape = float(np.mean(np.abs(errors / actual)) * 100)
        r2 = float(1 - np.sum(errors ** 2) / np.sum((actual - actual.mean()) ** 2))
        return rmse, mape, r2

    def get_predictions(self, symbol, days, lookback):
        """Fetch data, load or train the model, and return in-sample predictions.

        Returns a dict with the dates, actual and predicted prices, per-sample
        confidence, RMSE/MAPE/R² metrics, and the latest price, volume, RSI,
        and MACD, plus the mean volatility in percent.

        Raises:
            ValueError: wrapping any failure, with the original exception chained.
        """
        try:
            df = self.get_data(symbol, days)
            X, y = self.prepare_data(df, lookback)
            model_path = self.get_model_path(symbol)
            if os.path.exists(model_path):
                model = CryptoPredictor(X.shape[2]).to(self.device)
                # map_location lets GPU-saved weights load on a CPU-only host.
                model.load_state_dict(torch.load(model_path, map_location=self.device))
            else:
                model = self.train_model(X, y, symbol)
                joblib.dump(self.scaler, self.get_scaler_path(symbol))
            model.eval()
            with torch.no_grad():
                predictions, confidence = model(X)
            predicted_prices = self.price_scaler.inverse_transform(predictions.cpu().numpy())
            actual_prices = self.price_scaler.inverse_transform(y.cpu().numpy())
            rmse, mape, r2 = self._regression_metrics(actual_prices, predicted_prices)
            dates = df.index[lookback:].strftime('%Y-%m-%d').tolist()
            return {
                'dates': dates,
                'actual': actual_prices.tolist(),
                'predicted': predicted_prices.tolist(),
                'confidence': confidence.cpu().numpy().flatten().tolist(),
                'rmse': rmse,
                'mape': mape,
                'r2': r2,
                'volatility': float(df['Volatility'].mean() * 100),
                'current_price': float(df['Close'].iloc[-1]),
                'volume': float(df['Volume'].iloc[-1]),
                'rsi': float(df['RSI'].iloc[-1]),
                'macd': float(df['MACD'].iloc[-1]),
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e
def main():
st.title("πŸš€ Cryptocurrency Price Prediction")
st.sidebar.header("Settings")
symbol = st.sidebar.selectbox("Select Cryptocurrency", ["BTC", "ETH", "BNB", "XRP", "ADA", "SOL", "DOT", "DOGE"], index=0)
custom_symbol = st.sidebar.text_input("Or enter custom symbol (e.g., MATIC)", "")
days = st.sidebar.slider("Historical Days", 30, 365, 180)
lookback = st.sidebar.slider("Lookback Period (Days)", 7, 60, 30)
symbol = custom_symbol if custom_symbol else symbol
if st.sidebar.button("πŸ“Š Generate Analysis"):
analyzer = CryptoAnalyzer()
try:
st.info("Fetching data and generating predictions...")
predictions = analyzer.get_predictions(symbol, days, lookback)
# Display results
st.subheader("πŸ“ˆ Price Prediction Results")
st.line_chart({
"Actual Prices": predictions['actual'],
"Predicted Prices": predictions['predicted']
})
st.subheader("πŸ“Š Model Metrics")
st.write(f"**RΒ² Score:** {predictions['r2']:.4f}")
st.write(f"**RMSE:** ${predictions['rmse']:.2f}")
st.write(f"**MAPE:** {predictions['mape']:.2f}%")
st.subheader("πŸ” Additional Indicators")
st.write(f"**RSI:** {predictions['rsi']:.2f}")
st.write(f"**MACD:** {predictions['macd']:.2f}")
st.write(f"**Volatility:** {predictions['volatility']:.2f}%")
except Exception as e:
st.error(f"⚠️ Error: {str(e)}")
if __name__ == "__main__":
main()