Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import matplotlib.pyplot as plt | |
| from dotenv import load_dotenv | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain.agents import Tool, initialize_agent | |
| from langchain.agents.agent_types import AgentType | |
# === Load environment ===
# load_dotenv() is called before the chat model is constructed
# (presumably so that credentials kept in a local .env file are visible to
# the client -- confirm which variables the deployment relies on).
load_dotenv()

# Shared chat model instance handed to the agent further down in this file.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
| # === Utility functions === | |
def parse_yearly_data(text):
    """Extract ``{year: value}`` pairs from free-form text.

    Two phrasings are recognised:

    * value first, joined by a preposition -- ``"260B in 2020"``
      (``in``, ``for`` or ``during``, case-insensitive);
    * year first -- ``"2020: 260B"`` (between 1 and 10 non-digit characters
      may separate the year from the value).

    Value-first matches are taken out of the text before the year-first
    pattern runs, so a year that belongs to ``"260B in 2020"`` is not paired
    again with the number that follows it.  Year-first pairs are stored
    after value-first pairs, so they win if the same year appears in both.

    Thousands separators (``,``) are stripped before conversion.  A trailing
    ``B``/``M`` unit letter is accepted but not applied as a multiplier:
    values are returned exactly as written.  Tokens that are not valid
    numbers once the commas are removed (e.g. a lone ``","``) are skipped.

    Args:
        text: Free-form text containing four-digit years and numbers.

    Returns:
        A dict mapping ``int`` years to ``float`` values, ordered by year.
        Empty when nothing could be parsed.
    """
    value_first = re.compile(
        r"(?<![\w.,])(\d[\d.,]*)\s?[BbMm]?\s+(?:in|for|during)\s+(\d{4})(?!\d)",
        re.IGNORECASE,
    )
    year_first = re.compile(r"(\d{4})[^0-9]{1,10}?([\d.,]+)[\s]?[BbMm]?")

    pairs = []
    leftovers = []
    cursor = 0
    for match in value_first.finditer(text):
        raw_value, year = match.groups()
        pairs.append((year, raw_value))
        # Keep only the text between value-first matches for the second pass.
        leftovers.append(text[cursor:match.start()])
        cursor = match.end()
    leftovers.append(text[cursor:])

    for segment in leftovers:
        pairs.extend(year_first.findall(segment))

    data = {}
    for year, raw_value in pairs:
        try:
            data[int(year)] = float(raw_value.replace(",", ""))
        except ValueError:
            # Punctuation-only or malformed token such as "," or "1.2.3".
            continue
    return dict(sorted(data.items()))
def forecast_next_value(data: dict):
    """Extend a yearly series by one year using the last observed change.

    The forecast is ``latest + (latest - previous)``, where ``latest`` and
    ``previous`` are the values of the two highest years in ``data``.  The
    step is not scaled by the gap between those two years.

    Args:
        data: Mapping of year -> value.  Entries may be in any order.

    Returns:
        A new dict ordered by year containing every input entry plus one
        forecast entry for ``max(year) + 1``.  ``data`` itself is left
        unmodified.

    Raises:
        ValueError: If ``data`` has fewer than two entries.
    """
    if len(data) < 2:
        raise ValueError("Need at least 2 years of data to forecast")

    # Sort a copy so "latest" means the highest year, not the last inserted
    # key, and so the caller's dict is never mutated.
    series = dict(sorted(data.items()))
    (_, previous), (latest_year, latest) = list(series.items())[-2:]
    series[latest_year + 1] = latest + (latest - previous)
    return series
def plot_forecast(data: dict, title="Financial Forecast", output_path="forecast.png"):
    """Draw the series as a line chart and save it as an image file.

    Args:
        data: Mapping of year -> value; plotted in the dict's iteration order.
        title: Chart title.
        output_path: Destination passed to ``savefig``.  Defaults to
            ``"forecast.png"``.

    Returns:
        None.  The only result is the file written to ``output_path``.
    """
    years = list(data.keys())
    values = list(data.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(years, values, marker='o', linestyle='-')
        ax.set_title(title)
        ax.set_xlabel("Year")
        ax.set_ylabel("Value")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated tool calls do not accumulate open figures.
        plt.close(fig)
| # === Tool function === | |
def forecast_tool_func(query: str) -> str:
    """Parse yearly figures from ``query``, forecast one more year, and plot.

    Runs ``parse_yearly_data`` -> ``forecast_next_value`` -> ``plot_forecast``
    and returns a one-line summary that reports the value stored under the
    highest year of the forecasted series.

    Args:
        query: Text containing the yearly figures.

    Returns:
        A human-readable summary string.
    """
    series = forecast_next_value(parse_yearly_data(query))
    plot_forecast(series, title="Forecast based on financial data")

    final_year = max(series.keys())
    return (
        f"📈 Forecast complete. Predicted value for {final_year}: "
        f"{series[final_year]:,.2f}. Plot saved as forecast.png."
    )
# === LangChain Tool ===
# Exposes forecast_tool_func to the agent under the name "FinancialForecaster".
forecast_tool = Tool(
    func=forecast_tool_func,
    name="FinancialForecaster",
    description=(
        "Use this tool to forecast financial metrics based on yearly values "
        "given in the user's query. It parses numerical data and extrapolates "
        "one year forward, also saving a plot."
    ),
)
# === LangChain Agent ===
# Agent built at import time from the module-level model, with the
# forecaster as its only tool.
dc_agent = initialize_agent(
    llm=llm,
    tools=[forecast_tool],
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# === Entry point ===
# Runs one sample query through the agent when the file is executed directly.
if __name__ == "__main__":
    query = (
        "Apple's revenue was 260B in 2020, 320B in 2021, "
        "350B in 2022, and 400B in 2023"
    )
    result = dc_agent.run(query)
    print("\n Agent Response:\n", result)