# QALocalLLM / src / FinancialAgent.py
# Author: mrfirdauss — commit 70db80b: "fix: change ollama to api"
from abc import ABC, abstractmethod
from prompt import FINAL_PROMPT
from models import ResponseState
import streamlit as st
import pickle
import matplotlib.pyplot as plt
import io
import numpy as np
import pandas as pd
class FinancialAgentFactory(ABC):
    """Abstract Factory for creating Financial Agents.

    Wires a Streamlit chat UI to an LLM backend: renders chat history,
    executes model-generated pandas/matplotlib code against a pre-pickled
    fraud-detection DataFrame, and streams a final contextualized answer.

    Subclasses must implement the model-specific hooks
    ``__stream_answer__`` and ``process_prompt``.
    """

    def __init__(self, st: st, model_name: str = "gpt-4o"):
        """Bind the Streamlit module, load the dataset, and init session state.

        Args:
            st: The ``streamlit`` module (injected so subclasses/tests can
                substitute it; annotation mirrors the original code).
            model_name: Identifier of the chat model; stored in session state.
        """
        self.st = st
        # Context manager closes the file handle promptly (the original
        # ``pickle.load(open(...))`` leaked it until garbage collection).
        with open("fraudTrainData.pkl", "rb") as fh:
            self.df = pickle.load(fh)
        self.model_name = model_name
        if "messages" not in self.st.session_state:
            # First run for this session: start an empty history and
            # record which model is in use.
            self.st.session_state.messages = []
            self.st.session_state["openai_model"] = self.model_name

    def render_header(self, header: str = "Financial Agent"):
        """Render the page title."""
        self.st.title(header)

    def render_messages(self):
        """Render previous chat messages from session state."""
        for message in self.st.session_state.messages:
            with self.st.chat_message(message["role"]):
                self.st.markdown(message["content"])

    @abstractmethod
    def __stream_answer__(self, instructions, input_messages):
        """Stream answer tokens from the model.

        Args:
            instructions: System/instruction prompt for the model.
            input_messages: Chat history as ``{"role", "content"}`` dicts.
        """
        pass

    @abstractmethod
    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input."""
        pass

    def __safe_savefig__(self, *args, **kwargs):
        """Capture the current matplotlib figure into an in-memory PNG buffer.

        Exposed to model-generated code (as ``savefig``) so it cannot write
        to the server filesystem. Extra args are accepted and ignored to
        tolerate arbitrary ``savefig(...)`` call shapes.

        Returns:
            io.BytesIO: Buffer positioned at the start of the PNG bytes.
        """
        # NOTE: the original signature omitted ``self``; it only worked
        # because ``*args`` silently swallowed the bound instance.
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        return buf

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Handle additional context (data, PDF, etc.).

        Executes model-generated analysis code against the DataFrame and,
        if the code produced a chart, renders it in the chat.

        Args:
            response_state: Parsed model response carrying ``contextType``
                and the generated ``code`` string.

        Returns:
            str: Context block to append to the final prompt ("" if none).
        """
        context_prompt = ""
        if response_state.contextType in ("data", "both"):
            local_scope = {
                "df": self.df,
                "np": np,
                "pd": pd,
                "plt": plt,
                "savefig": self.__safe_savefig__,
            }
            # SECURITY: ``exec`` runs LLM-generated code with access to the
            # dataset and matplotlib. This is deliberate here, but the code
            # is untrusted input — consider sandboxing (restricted globals,
            # subprocess with resource limits) before production use.
            exec(response_state.code, {}, local_scope)
            fig = plt.gcf()
            if fig.get_axes():  # if a chart was generated
                with self.st.chat_message("assistant"):
                    self.st.pyplot(fig)
                plt.close(fig)
            context_prompt = "## CONTEXT DATAFRAME.\n"
            # Generated code conventionally stores its answer in ``result``.
            context_prompt += str(local_scope.get("result", ""))
        # Placeholder for PDF or other context handling
        # elif response_state.contextType in ("pdf", "both"):
        #     context_prompt = "Provide the relevant information from the PDF documents."
        return context_prompt

    def generate_final_answer(self, context_prompt: str):
        """Generate and stream the final answer with context.

        Streams tokens into the chat UI and appends the completed answer to
        the session history.
        """
        # Consistency fix: use the injected ``self.st`` like the rest of the
        # class rather than the module-level ``st``.
        with self.st.chat_message("assistant"):
            answer = self.st.write_stream(
                self.__stream_answer__(
                    instructions=FINAL_PROMPT,
                    input_messages=[
                        {"role": m["role"], "content": m["content"]}
                        for m in self.st.session_state.messages
                    ] + [{"role": "user", "content": context_prompt}],
                )
            )
        self.st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )

    def display_final_answer(self, answer: str):
        """Display a non-streamed assistant answer and record it in history."""
        self.st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )
        with self.st.chat_message("assistant"):
            self.st.markdown(answer)

    def run(self):
        """Run the app: render header + history, then process new input."""
        self.render_header()
        self.render_messages()
        if prompt := self.st.chat_input("What is up?"):
            self.process_prompt(prompt)