Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| import pickle | |
| from models import ResponseState | |
| from prompt import REFINERY_PROMPT, FINAL_PROMPT | |
| from langchain_community.vectorstores import FAISS | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from openai import OpenAI | |
| import pickle | |
| import io | |
| import json | |
| from OllamaCustomLocalAPIClient import OllamaCustomLocalAPIClient | |
class FinancialAgentApp(ABC):
    """Base chat application that answers questions about a pickled DataFrame.

    The class drives a Streamlit-style chat UI through the injected ``st``
    object. Subclasses must provide ``self.client`` and implement
    ``__stream_answer__``; they may extend ``__handle_context__`` to add
    further context sources.
    """

    def __init__(self, st, model_name):
        """Store the UI handle, load the dataset and initialise chat history.

        Args:
            st: Streamlit-like module/object used for all rendering and for
                ``session_state``.
            model_name: Model identifier forwarded to the client on each call.
        """
        self.st = st
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # "fraudTrainData.pkl" from a trusted location.
        with open("fraudTrainData.pkl", "rb") as data_file:
            self.df = pickle.load(data_file)
        self.model_name = model_name
        if "messages" not in self.st.session_state:
            self.st.session_state.messages = []

    def render_header(self):
        """Render the page title."""
        self.st.title("Financial Agent")

    def render_messages(self):
        """Render previous chat messages with roles."""
        for message in self.st.session_state.messages:
            role = message.get("role", "assistant")  # default to assistant if missing
            with self.st.chat_message(role):
                if message.get("type") == "plot":
                    self.st.pyplot(message["content"])
                else:
                    self.st.markdown(message["content"])

    def _text_messages(self):
        """Return the chat history as role/content dicts, skipping plot entries.

        Plot entries hold a figure object as their content, which cannot be
        sent to a model, so they are dropped rather than forwarded.
        """
        return [
            {"role": m["role"], "content": m["content"]}
            for m in self.st.session_state.messages
            if m.get("type") != "plot"
        ]

    @abstractmethod
    def __stream_answer__(self, instructions, input_messages):
        """Yield the model's answer piece by piece.

        Args:
            instructions: System-level instructions for the model.
            input_messages: List of role/content dicts forming the conversation.
        """
        raise NotImplementedError

    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input."""
        self.st.session_state.messages.append({"role": "user", "content": prompt})
        with self.st.chat_message("user"):
            self.st.markdown(prompt)

        # Step 1: Run refinery prompt
        response = self.client.responses.parse(
            model=self.model_name,
            instructions=REFINERY_PROMPT.format(
                df_head=self.df.head().to_markdown(),
                df_columns=self.df.columns.tolist(),
                df_sample=self.df.sample(5).to_markdown(),
            ),
            input=self._text_messages(),
            stream=False,
            text_format=ResponseState,
        )
        response_state = response.output_parsed

        # Step 2: Check if context is needed
        if response_state.isNeedContext:
            context_prompt = self.__handle_context__(response_state)
            self.generate_final_answer(context_prompt)
        else:
            self.display_final_answer(response_state.response)

    @staticmethod
    def __safe_savefig__(*args, **kwargs):
        """Save the current figure to an in-memory PNG buffer and return it.

        All arguments are ignored; this is exposed to generated code as
        ``savefig`` so that nothing is written to disk.
        """
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        return buf

    def __handle_context__(self, response_state: "ResponseState") -> str:
        """Handle context if need to add context from data/pdf.

        For the "data" and "both" context types this shows the generated code,
        executes it against the DataFrame, renders any resulting plot and
        returns a prompt section containing the code's ``result`` variable.
        For other context types an empty string is returned.
        """
        context_prompt = ""
        if response_state.contextType in ("data", "both"):
            code_markdown = f"```python\n{response_state.code}\n```"
            with self.st.chat_message("assistant"):
                self.st.markdown(code_markdown)
            self.st.session_state.messages.append(
                {"role": "assistant", "content": code_markdown}
            )

            scope = {
                "df": self.df,
                "np": np,
                "pd": pd,
                "plt": plt,
                "savefig": self.__safe_savefig__,
            }
            # NOTE(security): this runs model-generated code with full
            # interpreter privileges; it should be sandboxed.
            # A single namespace is used on purpose: with separate globals and
            # locals the code runs like a class body, so lambdas and nested
            # functions in it cannot see names such as ``df``.
            exec(response_state.code, scope)

            fig = plt.gcf()
            if fig.get_axes():
                with self.st.chat_message("assistant"):
                    self.st.pyplot(fig)
                self.st.session_state.messages.append({
                    "role": "assistant",
                    "type": "plot",
                    "content": fig,
                })
            # gcf() creates a figure when none exists, so always release it.
            plt.close(fig)

            context_prompt = "## CONTEXT DATAFRAME.\n"
            context_prompt += str(scope.get("result", ""))
        return context_prompt

    def generate_final_answer(self, context_prompt: str):
        """Generate and stream the final answer with context."""
        conversation = self._text_messages()
        conversation.append({"role": "user", "content": context_prompt})
        with self.st.chat_message("assistant"):
            answer = self.st.write_stream(
                self.__stream_answer__(
                    instructions=FINAL_PROMPT,
                    input_messages=conversation,
                )
            )
        self.st.session_state.messages.append({"role": "assistant", "content": answer})

    def display_final_answer(self, answer: str):
        """Display a non-streamed assistant answer."""
        self.st.session_state.messages.append({"role": "assistant", "content": answer})
        with self.st.chat_message("assistant"):
            self.st.markdown(answer)

    def run(self):
        """Run the app."""
        self.render_header()
        self.render_messages()
        if prompt := self.st.chat_input("What is up?"):
            self.process_prompt(prompt)
class HFFinancialRAG(FinancialAgentApp):
    """Financial agent that talks to an OpenAI-compatible endpoint.

    PDF context is retrieved from a FAISS index loaded from local storage.
    """

    def __init__(self, st, base_url, api_key, model_name = 'Qwen/Qwen3-4B', vector_id="vs_68bf713eea2c81919ac08298a05d6704", embedding=None):
        """Validate the credentials, then build the client and vector index.

        Raises:
            ValueError: If ``base_url`` or ``api_key`` is missing or empty.
        """
        if not base_url:
            raise ValueError("base_url cannot be None or empty.")
        if not api_key:
            raise ValueError("api_key cannot be None or empty.")

        super().__init__(st, model_name)
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        # NOTE(security): allow_dangerous_deserialization unpickles the index
        # from disk; the files at ``vector_id`` must be trusted.
        self.vector_db = FAISS.load_local(
            vector_id,
            embedding,
            allow_dangerous_deserialization=True,
        )

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Extend the base context with the top matching PDF passages."""
        base_context = super().__handle_context__(response_state)
        if response_state.contextType not in ("pdf", "both"):
            return base_context

        matches = self.vector_db.similarity_search(response_state.retriverKey, k=3)
        sections = [
            f"### Document {position}\n{match.page_content}\n"
            for position, match in enumerate(matches, start=1)
        ]
        return base_context + "## CONTEXT PDF.\n" + "".join(sections)

    def __stream_answer__(self, instructions, input_messages):
        """Yield the text deltas of a streamed response."""
        events = self.client.responses.create(
            model=self.model_name,
            instructions=instructions,
            input=input_messages,
            stream=True,
        )
        # Only text-delta events carry answer text; skip every other event.
        yield from (
            event.delta
            for event in events
            if event.type == 'response.output_text.delta'
        )
class OpenAIFinancialRAG(FinancialAgentApp):
    """Financial agent backed by the OpenAI API with hosted file search."""

    # Vector store used when the caller does not supply one.
    DEFAULT_VECTOR_STORE_ID = 'vs_68bf713eea2c81919ac08298a05d6704'

    def __init__(self, st, model_name = "gpt-5-mini-2025-08-07", vector_store_ids=None):
        """Create the OpenAI client.

        Args:
            st: Streamlit-like module/object used for rendering and state.
            model_name: Model identifier forwarded to the client.
            vector_store_ids: Optional iterable of vector store ids for the
                file-search tool; defaults to ``DEFAULT_VECTOR_STORE_ID``.
        """
        super().__init__(st, model_name)
        self.client = OpenAI()
        if vector_store_ids:
            self.vector_store_ids = list(vector_store_ids)
        else:
            self.vector_store_ids = [self.DEFAULT_VECTOR_STORE_ID]

    def __stream_answer__(self, instructions, input_messages):
        """Yield the text deltas of a streamed response that may use file search."""
        response_stream = self.client.responses.create(
            model=self.model_name,
            instructions=instructions,
            input=input_messages,
            stream=True,
            tools=[{
                "type": "file_search",
                "vector_store_ids": self.vector_store_ids,
            }],
        )
        for chunk in response_stream:
            # Only text-delta events carry answer text.
            if chunk.type == 'response.output_text.delta':
                yield chunk.delta

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Handle additional context (data, PDF, etc.).

        Returns the base-class context unchanged; the context is only logged
        at debug level for troubleshooting.
        """
        import logging

        context_prompt = super().__handle_context__(response_state)
        # Debug log instead of print: the context can contain dataset rows
        # that should not be written to stdout unconditionally.
        logging.getLogger(__name__).debug("context %s", context_prompt)
        return context_prompt
class OllamaAPIFinancialRAG(FinancialAgentApp):
    """Financial agent that talks to a local Ollama API through a custom client.

    PDF context is retrieved from a FAISS index loaded from local storage.
    """

    def __init__(self, st, base_url, model_name = 'qwen3:4b', vector_id="vs_68bf713eea2c81919ac08298a05d6704", embedding=None, api_key=None):
        """Validate the endpoint, then build the client and vector index.

        Args:
            st: Streamlit-like module/object used for rendering and state.
            base_url: Base URL of the Ollama API; required.
            model_name: Model identifier forwarded to the client.
            vector_id: Path passed to ``FAISS.load_local``.
            embedding: Embedding object passed to ``FAISS.load_local``.
            api_key: Optional key forwarded to the client. Added as the last
                parameter so existing positional calls keep working.

        Raises:
            ValueError: If ``base_url`` is missing or empty.
        """
        if not base_url:
            raise ValueError("base_url cannot be None or empty.")
        super().__init__(st, model_name)
        self.client = OllamaCustomLocalAPIClient(base_url=base_url, api_key=api_key)
        # NOTE(security): allow_dangerous_deserialization unpickles the index
        # from disk; the files at ``vector_id`` must be trusted.
        self.vector_db = FAISS.load_local(vector_id, embedding, allow_dangerous_deserialization=True)

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Handle additional context (data, PDF, etc.)."""
        # The trailing double underscore matters: without it the name is
        # mangled per class and the base implementation is never found.
        context_prompt = super().__handle_context__(response_state)
        if response_state.contextType in ("pdf", "both"):
            context_prompt += "## CONTEXT PDF.\n"
            results = self.vector_db.similarity_search(response_state.retriverKey, k=3)
            for i, doc in enumerate(results, 1):
                context_prompt += f"### Document {i}\n{doc.page_content}\n"
        return context_prompt

    def __stream_answer__(self, instructions, input_messages):
        """Yield the streamed answer returned by the chat client.

        The instructions are appended as a final user message because the
        chat call takes a single message list.
        """
        response_stream = self.client.chat(
            model=self.model_name,
            messages=input_messages + [{"role": "user", "content": instructions}],
            stream=True
        )
        # NOTE(review): yields the client's 'stream' value as one item;
        # confirm against OllamaCustomLocalAPIClient whether it should be
        # iterated instead.
        yield response_stream['message']['stream']

    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input."""
        self.st.session_state.messages.append({"role": "user", "content": prompt})
        with self.st.chat_message("user"):
            self.st.markdown(prompt)

        # Plot entries hold a figure object as content and cannot be sent to
        # the model, so only textual messages are forwarded.
        history = [
            {"role": m["role"], "content": m["content"]}
            for m in self.st.session_state.messages
            if m.get("type") != "plot"
        ]

        # Step 1: Run refinery prompt
        response = self.client.chat(
            model=self.model_name,
            messages=history + [{'role': 'user', 'content': REFINERY_PROMPT.format(
                df_head=self.df.head().to_markdown(),
                df_columns=self.df.columns.tolist(),
                df_sample=self.df.sample(5).to_markdown()
            )}],
            stream=False,
            text_format=ResponseState
        )
        response_state: ResponseState = ResponseState.model_validate_json(
            response["message"]["content"]
        )

        # Step 2: Check if context is needed
        if response_state.isNeedContext:
            context_prompt = self.__handle_context__(response_state)
            self.generate_final_answer(context_prompt)
        else:
            self.display_final_answer(response_state.response)