Spaces:
Sleeping
Sleeping
| from FinancialAgent import FinancialAgentFactory | |
| from prompt import REFINERY_PROMPT | |
| from models import ResponseState | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import matplotlib.pyplot as plt | |
| from langchain_community.vectorstores import FAISS | |
| from OllamaAPIClient import OllamaAPIClient | |
class FinancialAgentOllama(FinancialAgentFactory):
    """Concrete Financial Agent using Ollama.

    Each user prompt goes through a two-step pipeline (see ``process_prompt``):

    1. A non-streaming "refinery" chat call asks the model to fill a structured
       ``ResponseState``.
    2. If that state requests context, a second, streamed call answers using the
       DataFrame result and/or the PDF passages gathered by ``__handle_context__``.
       Otherwise ``response_state.response`` is shown directly.

    NOTE(review): ``self.st``, ``self.model_name``, ``self.df``,
    ``self.__safe_savefig__`` and ``self.display_final_answer`` are not defined
    here. They are assumed to come from ``FinancialAgentFactory``; confirm.
    """

    # Folder of the FAISS index loaded in __init__ (resolved relative to the
    # process working directory).
    DEFAULT_VECTOR_STORE_PATH = "vs_68bf713eea2c81919ac08298a05d6704"

    def __init__(self, st, model_name="qwen3:4b",
                 url="https://mrfirdauss-ollama-api.hf.space", embedding=None,
                 vector_store_path=DEFAULT_VECTOR_STORE_PATH):
        """Create the agent, its Ollama client and its FAISS vector store.

        Args:
            st: Streamlit module (or compatible object), forwarded to the base
                class. The methods below read it back as ``self.st``.
            model_name: Ollama model used for both refinery and answer calls.
            url: Base URL of the Ollama API service.
            embedding: Embeddings object handed to ``FAISS.load_local``.
                NOTE(review): with the default ``None``, PDF similarity search
                will presumably fail. Confirm callers always pass one.
            vector_store_path: Folder of the saved FAISS index. Defaults to
                the index previously hard-coded here.
        """
        super().__init__(st, model_name)
        self.client = OllamaAPIClient(url)
        # allow_dangerous_deserialization=True lets LangChain unpickle the
        # stored docstore. Only point this at an index you built yourself.
        self.vector_db = FAISS.load_local(
            vector_store_path, embedding, allow_dangerous_deserialization=True
        )

    def _history_messages(self):
        """Return the chat history as plain ``{"role", "content"}`` dicts.

        Any extra keys stored in ``st.session_state.messages`` are dropped, so
        only role/content are sent to the model.
        """
        return [
            {"role": m["role"], "content": m["content"]}
            for m in self.st.session_state.messages
        ]

    def __stream_answer__(self, instructions, input_messages):
        """Stream the model reply to *instructions* appended after *input_messages*.

        Yields the ``message.content`` text of each streamed chunk. Chunks
        without that field are skipped.
        """
        response_stream = self.client.chat(
            model=self.model_name,
            messages=input_messages + [{"role": "user", "content": instructions}],
            stream=True,
        )
        for chunk in response_stream:
            if "message" in chunk and "content" in chunk["message"]:
                yield chunk["message"]["content"]

    def generate_final_answer(self, context_prompt):
        """Generate the final answer from *context_prompt*, store it, and rerun.

        The reply is streamed into an assistant chat bubble, appended to
        ``st.session_state.messages``, and then a Streamlit rerun is requested.
        """
        with self.st.chat_message("assistant"):
            answer = self.st.write_stream(
                self.__stream_answer__(context_prompt, self._history_messages())
            )
        self.st.session_state.messages.append({"role": "assistant", "content": answer})
        # st.experimental_rerun is deprecated in favour of st.rerun and removed
        # in later Streamlit releases. Use the old name only when st.rerun is
        # missing (older Streamlit versions).
        rerun = getattr(self.st, "rerun", None) or self.st.experimental_rerun
        rerun()

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Build the context text for the final-answer prompt.

        For ``contextType`` "data" or "both", this executes the model-written
        ``response_state.code`` with ``df``, ``np``, ``pd``, ``plt`` and
        ``savefig`` in scope, and includes the value it assigns to ``result``.
        If that code drew a chart, the chart is shown in an assistant chat
        bubble.

        For "pdf" or "both", it appends the top-5 FAISS matches for
        ``response_state.retriverKey``.

        Returns:
            The assembled context string (empty if neither branch applies).
        """
        context_prompt = ""
        if response_state.contextType in ("data", "both"):
            # Use ONE dict as both globals and locals. With separate dicts the
            # code runs like a class body: lambdas, nested functions and
            # (pre-3.12) comprehensions could not see df/np/pd/plt.
            namespace = {
                "df": self.df,
                "np": np,
                "pd": pd,
                "plt": plt,
                "savefig": self.__safe_savefig__,
            }
            # SECURITY: this runs LLM-generated code with full interpreter
            # access and no sandbox. Do not expose this agent to untrusted users.
            exec(response_state.code, namespace)
            fig = plt.gcf()
            try:
                if fig.get_axes():  # only render when the code actually drew something
                    with self.st.chat_message("assistant"):
                        self.st.pyplot(fig)
            finally:
                # Close even empty figures: plt.gcf() creates one when none
                # exists, and pyplot keeps every unclosed figure alive.
                plt.close(fig)
            context_prompt = "## CONTEXT DATAFRAME.\n"
            # Trailing newline keeps a following "## CONTEXT PDF." header on
            # its own line.
            context_prompt += str(namespace.get("result", "")) + "\n"
        if response_state.contextType in ("pdf", "both"):
            context_prompt += "## CONTEXT PDF.\n"
            results = self.vector_db.similarity_search(response_state.retriverKey, k=5)
            context_prompt += "".join(
                f"### Document {i}\n{doc.page_content}\n"
                for i, doc in enumerate(results, 1)
            )
        return context_prompt

    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input.

        1. Record the user message in session state and display it.
        2. Ask the model (non-streaming, ``format=ResponseState``) to fill a
           ``ResponseState`` from ``REFINERY_PROMPT``, which is given the
           JSON schema plus a head, column list and sample of ``self.df``.
        3. If ``isNeedContext`` is set, gather context and stream a final
           answer. Otherwise pass ``response_state.response`` to
           ``display_final_answer``.
        """
        self.st.session_state.messages.append({"role": "user", "content": prompt})
        with self.st.chat_message("user"):
            self.st.markdown(prompt)

        # Step 1: run the refinery prompt.
        # DataFrame.sample(n) raises ValueError when n exceeds the row count
        # (sampling without replacement), so cap it for small tables.
        refinery_content = REFINERY_PROMPT.format(
            response_format=ResponseState.model_json_schema(),
            df_head=self.df.head().to_markdown(),
            df_columns=self.df.columns.tolist(),
            df_sample=self.df.sample(min(5, len(self.df))).to_markdown(),
        )
        response = self.client.chat(
            model=self.model_name,
            messages=self._history_messages()
            + [{"role": "user", "content": refinery_content}],
            format=ResponseState,
            stream=False,
        )
        response_state: ResponseState = ResponseState.model_validate_json(
            response["message"]["content"]
        )

        # Step 2: fetch context only when the model asked for it.
        if response_state.isNeedContext:
            context_prompt = self.__handle_context__(response_state)
            self.generate_final_answer(context_prompt)
        else:
            self.display_final_answer(response_state.response)