# Financial-RAG/src/FinancialAgentApp.py

import io
import json
import pickle
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from langchain_community.vectorstores import FAISS
from openai import OpenAI

from models import ResponseState
from OllamaCustomLocalAPIClient import OllamaCustomLocalAPIClient
from prompt import REFINERY_PROMPT, FINAL_PROMPT
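
# Architecture: FinancialAgentApp implements the abstract Streamlit chat loop;
# the concrete backends below (HFFinancialRAG, OpenAIFinancialRAG,
# OllamaAPIFinancialRAG) each supply a model client, a retrieval source, and a
# streaming strategy.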

class FinancialAgentApp(ABC):
    """Abstract Streamlit chat app over a fraud-transaction DataFrame."""

    def __init__(self, st, model_name):
        self.st = st
        with open("fraudTrainData.pkl", "rb") as f:
            self.df = pickle.load(f)
        self.model_name = model_name
        if "messages" not in self.st.session_state:
            self.st.session_state.messages = []

    def render_header(self):
        self.st.title("Financial Agent")

    def render_messages(self):
        """Render previous chat messages with their roles."""
        for message in self.st.session_state.messages:
            role = message.get("role", "assistant")  # default to assistant if missing
            if message.get("type") == "plot":
                with self.st.chat_message(role):
                    self.st.pyplot(message["content"])
            else:
                with self.st.chat_message(role):
                    self.st.markdown(message["content"])

    @abstractmethod
    def __stream_answer__(self, instructions, input_messages):
        """Stream the model's answer as a generator of text deltas."""
        pass

    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input."""
        self.st.session_state.messages.append({"role": "user", "content": prompt})
        with self.st.chat_message("user"):
            self.st.markdown(prompt)
        # Step 1: Run refinery prompt (self.client is set by the concrete subclasses).
        response = self.client.responses.parse(
            model=self.model_name,
            instructions=REFINERY_PROMPT.format(
                df_head=self.df.head().to_markdown(),
                df_columns=self.df.columns.tolist(),
                df_sample=self.df.sample(5).to_markdown()
            ),
            # Skip plot messages: their content is a matplotlib figure, not text.
            input=[{"role": m["role"], "content": m["content"]}
                   for m in self.st.session_state.messages if m.get("type") != "plot"],
            stream=False,
            text_format=ResponseState
        )
        response_state: ResponseState = response.output_parsed
        # Step 2: Check if context is needed
        if response_state.isNeedContext:
            context_prompt = self.__handle_context__(response_state)
            self.generate_final_answer(context_prompt)
        else:
            self.display_final_answer(response_state.response)
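
    # __safe_savefig__ is injected into the exec() namespace as `savefig` so that
    # model-generated code calling savefig(...) renders to an in-memory PNG
    # buffer instead of writing to the filesystem.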
    def __safe_savefig__(self, *args, **kwargs):
        """Capture plt.savefig output in an in-memory PNG buffer."""
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        return buf
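
    # The generated code is expected to leave its answer in a variable named
    # `result` (read back below) and may draw a matplotlib figure. A hypothetical
    # snippet, with column names assumed for illustration only:
    #
    #     result = df.groupby("is_fraud")["amt"].mean()
    #     result.plot(kind="bar")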
    @abstractmethod
    def __handle_context__(self, response_state: ResponseState) -> str:
        """Build the context prompt when extra data/PDF context is needed."""
        context_prompt = ""
        if response_state.contextType in ("data", "both"):
            with self.st.chat_message("assistant"):
                self.st.markdown(f"```python\n{response_state.code}\n```")
            self.st.session_state.messages.append(
                {"role": "assistant", "content": f"```python\n{response_state.code}\n```"}
            )
            local_scope = {"df": self.df, "np": np, "pd": pd, "plt": plt, "savefig": self.__safe_savefig__}
            # Pass local_scope as globals too, so functions defined inside the
            # generated code can still resolve df/np/pd/plt.
            exec(response_state.code, local_scope, local_scope)
            fig = plt.gcf()
            if fig.get_axes():
                with self.st.chat_message("assistant"):
                    self.st.pyplot(fig)
                self.st.session_state.messages.append({
                    "role": "assistant",
                    "type": "plot",
                    "content": fig
                })
                plt.close(fig)
            context_prompt = "## CONTEXT DATAFRAME.\n"
            context_prompt += str(local_scope.get("result", ""))
        return context_prompt

    def generate_final_answer(self, context_prompt: str):
        """Generate and stream the final answer with context."""
        with self.st.chat_message("assistant"):
            answer = self.st.write_stream(
                self.__stream_answer__(
                    instructions=FINAL_PROMPT,
                    input_messages=[
                        {"role": m["role"], "content": m["content"]}
                        for m in self.st.session_state.messages
                        if m.get("type") != "plot"  # figures cannot be sent back as text
                    ] + [{"role": "user", "content": context_prompt}]
                )
            )
        self.st.session_state.messages.append({"role": "assistant", "content": answer})

    def display_final_answer(self, answer: str):
        """Display a non-streamed assistant answer."""
        self.st.session_state.messages.append({"role": "assistant", "content": answer})
        with self.st.chat_message("assistant"):
            self.st.markdown(answer)

    def run(self):
        """Run the app."""
        self.render_header()
        self.render_messages()
        if prompt := self.st.chat_input("What is up?"):
            self.process_prompt(prompt)
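
# HFFinancialRAG talks to an OpenAI-compatible endpoint (e.g. a Hugging Face
# router) and retrieves PDF context from a FAISS index stored on disk; the
# `vector_id` default doubles as the local folder name for FAISS.load_local.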
class HFFinancialRAG(FinancialAgentApp):
    def __init__(self, st, base_url, api_key, model_name='Qwen/Qwen3-4B',
                 vector_id="vs_68bf713eea2c81919ac08298a05d6704", embedding=None):
        if not base_url:
            raise ValueError("base_url cannot be None or empty.")
        if not api_key:
            raise ValueError("api_key cannot be None or empty.")
        super().__init__(st, model_name)
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.vector_db = FAISS.load_local(vector_id, embedding, allow_dangerous_deserialization=True)

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Handle additional context (data, PDF, etc.)."""
        context_prompt = super().__handle_context__(response_state)
        if response_state.contextType in ("pdf", "both"):
            context_prompt += "## CONTEXT PDF.\n"
            results = self.vector_db.similarity_search(response_state.retriverKey, k=3)
            for i, doc in enumerate(results, 1):
                context_prompt += f"### Document {i}\n{doc.page_content}\n"
        return context_prompt

    def __stream_answer__(self, instructions, input_messages):
        response_stream = self.client.responses.create(
            model=self.model_name,
            instructions=instructions,
            input=input_messages,
            stream=True
        )
        for chunk in response_stream:
            if chunk.type == 'response.output_text.delta':
                yield chunk.delta
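
# OpenAIFinancialRAG delegates PDF retrieval to OpenAI's hosted file_search tool
# (with a server-side vector store) instead of a local FAISS index.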
class OpenAIFinancialRAG(FinancialAgentApp):
    def __init__(self, st, model_name="gpt-5-mini-2025-08-07"):
        super().__init__(st, model_name)
        self.client = OpenAI()

    def __stream_answer__(self, instructions, input_messages):
        response_stream = self.client.responses.create(
            model=self.model_name,
            instructions=instructions,
            input=input_messages,
            stream=True,
            tools=[{
                "type": "file_search",
                "vector_store_ids": ['vs_68bf713eea2c81919ac08298a05d6704']
            }]
        )
        for chunk in response_stream:
            if chunk.type == 'response.output_text.delta':
                yield chunk.delta

    def __handle_context__(self, response_state: ResponseState):
        """Handle additional context (data, PDF, etc.)."""
        context_prompt = super().__handle_context__(response_state)
        print("context", context_prompt)  # debug trace of the assembled context
        return context_prompt
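
# The Ollama backend cannot reuse the base process_prompt (its custom client
# exposes chat() rather than the OpenAI Responses API), so it re-implements the
# refinery step and parses the structured output with Pydantic.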
class OllamaAPIFinancialRAG(FinancialAgentApp):
    def __init__(self, st, base_url, api_key=None, model_name='qwen3:4b',
                 vector_id="vs_68bf713eea2c81919ac08298a05d6704", embedding=None):
        if not base_url:
            raise ValueError("base_url cannot be None or empty.")
        super().__init__(st, model_name)
        self.client = OllamaCustomLocalAPIClient(base_url=base_url, api_key=api_key)
        self.vector_db = FAISS.load_local(vector_id, embedding, allow_dangerous_deserialization=True)

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Handle additional context (data, PDF, etc.)."""
        context_prompt = super().__handle_context__(response_state)
        if response_state.contextType in ("pdf", "both"):
            context_prompt += "## CONTEXT PDF.\n"
            results = self.vector_db.similarity_search(response_state.retriverKey, k=3)
            for i, doc in enumerate(results, 1):
                context_prompt += f"### Document {i}\n{doc.page_content}\n"
        return context_prompt

    def __stream_answer__(self, instructions, input_messages):
        response_stream = self.client.chat(
            model=self.model_name,
            messages=input_messages + [{"role": "user", "content": instructions}],
            stream=True
        )
        # Assumes Ollama-style streaming chunks that mirror the non-streamed
        # response shape used in process_prompt below.
        for chunk in response_stream:
            yield chunk["message"]["content"]

    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input."""
        self.st.session_state.messages.append({"role": "user", "content": prompt})
        with self.st.chat_message("user"):
            self.st.markdown(prompt)
        # Step 1: Run refinery prompt (plot messages are skipped: their content
        # is a matplotlib figure, not text).
        response = self.client.chat(
            model=self.model_name,
            messages=[{"role": m["role"], "content": m["content"]}
                      for m in self.st.session_state.messages if m.get("type") != "plot"]
            + [{'role': 'user', 'content': REFINERY_PROMPT.format(
                df_head=self.df.head().to_markdown(),
                df_columns=self.df.columns.tolist(),
                df_sample=self.df.sample(5).to_markdown()
            )}],
            stream=False,
            text_format=ResponseState
        )
        response_state: ResponseState = ResponseState.model_validate_json(
            response["message"]["content"]
        )
        # Step 2: Check if context is needed
        if response_state.isNeedContext:
            context_prompt = self.__handle_context__(response_state)
            self.generate_final_answer(context_prompt)
        else:
            self.display_final_answer(response_state.response)
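
# Minimal usage sketch, assuming env vars HF_BASE_URL / HF_API_KEY and an
# embedding model matching the one used to build the FAISS index (the model
# name below is illustrative, not taken from this repo).
if __name__ == "__main__":
    import os
    import streamlit as st
    from langchain_community.embeddings import HuggingFaceEmbeddings

    app = HFFinancialRAG(
        st,
        base_url=os.environ.get("HF_BASE_URL"),
        api_key=os.environ.get("HF_API_KEY"),
        embedding=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
    )
    app.run()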