# Source: agent1/TrendAnalysisAgent/tools/DataAnalysisTool.py
# Uploaded by varun324242 using huggingface_hub (commit 06ba83e, verified).
from agency_swarm.tools import BaseTool
from pydantic import Field, BaseModel
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
class DataAnalysisTool(BaseTool):
    """
    This tool processes and analyzes market data using pandas and scikit-learn.
    It performs data cleaning, feature extraction, and trend detection using a linear regression model.
    """

    data: pd.DataFrame = Field(
        ..., description="The market data as a pandas DataFrame."
    )

    class Config:
        # pd.DataFrame is not a type pydantic can validate natively, so it is
        # accepted as an opaque ("arbitrary") type.
        arbitrary_types_allowed = True

    def run(self):
        """
        Fit a linear regression of 'Close' on 'Open', 'High', 'Low' and 'Volume'.

        Rows with missing values in any of those five columns are dropped, the
        remaining rows are split 80/20 into train/test sets (random_state=42),
        features are standardized using statistics from the training set only,
        and the model is evaluated on the test set.

        Returns:
            tuple: ``(mse, results_df)`` where ``mse`` is the test-set mean
            squared error and ``results_df`` is a DataFrame with 'Actual' and
            'Predicted' columns, indexed like the test rows of ``self.data``.

        Raises:
            KeyError: if any of the required columns is absent from ``self.data``.
            ValueError: if fewer than 2 complete rows remain after cleaning,
                which is too few to form both a training and a test set.
        """
        feature_columns = ['Open', 'High', 'Low', 'Volume']
        target_column = 'Close'
        required_columns = feature_columns + [target_column]

        # Fail early with an explicit message instead of a bare pandas KeyError.
        missing = [col for col in required_columns if col not in self.data.columns]
        if missing:
            raise KeyError(f"Market data is missing required column(s): {missing}")

        # Data Cleaning: only gaps in the columns the model actually uses should
        # discard a row; NaNs in unrelated columns must not throw away usable data.
        cleaned_data = self.data.dropna(subset=required_columns)
        if len(cleaned_data) < 2:
            raise ValueError(
                "Need at least 2 rows with complete "
                f"{required_columns} values to train and evaluate the model; "
                f"got {len(cleaned_data)}."
            )

        # Feature Extraction.
        # NOTE(review): features come from the same row as the target, so this
        # measures how well a row's Close is explained by that row's own
        # Open/High/Low/Volume rather than forecasting a future Close.
        features = cleaned_data[feature_columns]
        target = cleaned_data[target_column]

        # NOTE(review): train_test_split shuffles rows by default; if the data is
        # time-ordered, later rows can leak into training. Consider shuffle=False.
        X_train, X_test, y_train, y_test = train_test_split(
            features, target, test_size=0.2, random_state=42
        )

        # Fit the scaler on the training set only, to avoid leaking test statistics.
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        model = LinearRegression()
        model.fit(X_train_scaled, y_train)

        predictions = model.predict(X_test_scaled)
        mse = mean_squared_error(y_test, predictions)

        # Keep the original row labels so predictions can be joined back to the input.
        results_df = pd.DataFrame(
            {'Actual': y_test, 'Predicted': predictions}, index=y_test.index
        )
        return mse, results_df
# Example usage:
# Assume 'market_data' is a pandas DataFrame with columns ['Open', 'High', 'Low', 'Close', 'Volume']
# tool = DataAnalysisTool(data=market_data)
# mse, results = tool.run()
# print("Mean Squared Error:", mse)
# print(results)