File size: 3,474 Bytes
18a508e
a9d6f87
18a508e
 
 
cb953d7
18a508e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70db80b
18a508e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from abc import ABC, abstractmethod
from prompt import FINAL_PROMPT
from models import ResponseState
import streamlit as st
import pickle
import matplotlib.pyplot as plt
import io
import numpy as np
import pandas as pd

class FinancialAgentFactory(ABC):
    """Abstract Factory for Streamlit-based financial chat agents.

    Holds the shared chat/session state and UI helpers; concrete
    subclasses supply the model integration by implementing
    ``__stream_answer__`` and ``process_prompt``.
    """

    def __init__(self, st, model_name="gpt-4o"):
        """Store the streamlit module, load the training dataframe, and
        initialise chat session state.

        Args:
            st: the ``streamlit`` module (injected so it can be swapped in
                tests; all methods go through ``self.st``).
            model_name: model identifier recorded under
                ``session_state["openai_model"]``.
        """
        self.st = st
        # Use a context manager so the file handle is closed; the original
        # `pickle.load(open(...))` leaked it.
        # SECURITY NOTE: unpickling executes arbitrary code — only load
        # trusted, locally produced pickle files.
        with open("fraudTrainData.pkl", "rb") as fh:
            self.df = pickle.load(fh)
        self.model_name = model_name

        # Ensure chat history exists before any rendering happens.
        if "messages" not in self.st.session_state:
            self.st.session_state.messages = []
        self.st.session_state["openai_model"] = self.model_name

    def render_header(self, header="Financial Agent"):
        """Render the page title."""
        self.st.title(header)

    def render_messages(self):
        """Render the previous chat messages stored in session state."""
        for message in self.st.session_state.messages:
            with self.st.chat_message(message["role"]):
                self.st.markdown(message["content"])

    @abstractmethod
    def __stream_answer__(self, instructions, input_messages):
        """Yield streamed answer chunks from the model (subclass hook)."""

    @abstractmethod
    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input (subclass hook)."""

    def __safe_savefig__(self, *args, **kwargs):
        """Capture the current matplotlib figure into an in-memory PNG.

        Explicit ``self`` added (the original relied on the bound call
        absorbing it into ``*args``). Extra arguments are accepted and
        ignored so generated code may call ``savefig(path, ...)`` without
        touching the filesystem.

        Returns:
            io.BytesIO: PNG bytes of the current figure, rewound to 0.
        """
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        return buf

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Handle additional context (data, PDF, etc.).

        Executes model-generated analysis code against ``self.df``,
        renders any produced chart, and returns a context string built
        from the ``result`` variable the generated code may set
        (empty string when no data context is requested).
        """
        context_prompt = ""
        if response_state.contextType in ("data", "both"):
            local_scope = {
                "df": self.df,
                "np": np,
                "pd": pd,
                "plt": plt,
                "savefig": self.__safe_savefig__,
            }
            # SECURITY: exec() of model-generated code is arbitrary code
            # execution; acceptable only because this runs locally on
            # trusted output — do not expose to untrusted input.
            exec(response_state.code, {}, local_scope)

            fig = plt.gcf()
            if fig.get_axes():  # if a chart was generated
                with self.st.chat_message("assistant"):
                    self.st.pyplot(fig)
                plt.close(fig)

            context_prompt = "## CONTEXT DATAFRAME.\n"
            context_prompt += str(local_scope.get("result", ""))

        # Placeholder for PDF or other context handling
        # elif response_state.contextType in ("pdf", "both"):
        #     context_prompt = "Provide the relevant information from the PDF documents."

        return context_prompt

    def generate_final_answer(self, context_prompt: str):
        """Generate and stream the final answer with context.

        Sends the full chat history plus the context prompt to the
        subclass streamer, streams the reply into the UI, and appends it
        to the history. Uses ``self.st`` consistently (the original mixed
        the injected module with the global ``st``).
        """
        with self.st.chat_message("assistant"):
            answer = self.st.write_stream(
                self.__stream_answer__(
                    instructions=FINAL_PROMPT,
                    input_messages=[
                        {"role": m["role"], "content": m["content"]}
                        for m in self.st.session_state.messages
                    ] + [{"role": "user", "content": context_prompt}],
                )
            )
        self.st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )

    def display_final_answer(self, answer: str):
        """Record and display a non-streamed assistant answer."""
        self.st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )
        with self.st.chat_message("assistant"):
            self.st.markdown(answer)

    def run(self):
        """Run the app: render header and history, then handle new input."""
        self.render_header()
        self.render_messages()

        if prompt := self.st.chat_input("What is up?"):
            self.process_prompt(prompt)