# QALocalLLM/src/FinancialAgentOllama.py
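"""Ollama-backed financial Q&A agent.

Pipeline: a refinery prompt classifies each user message, optionally executes
model-generated pandas/matplotlib code or retrieves PDF chunks from a FAISS
index, then streams the final answer into the Streamlit chat.
"""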
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from langchain_community.vectorstores import FAISS

from FinancialAgent import FinancialAgentFactory
from models import ResponseState
from OllamaAPIClient import OllamaAPIClient
from prompt import REFINERY_PROMPT


class FinancialAgentOllama(FinancialAgentFactory):
    """Concrete Financial Agent using Ollama."""

    def __init__(self, st, model_name="qwen3:4b", url="https://mrfirdauss-ollama-api.hf.space", embedding=None):
        super().__init__(st, model_name)
        self.client = OllamaAPIClient(url)
        # Pickle-based FAISS loading requires an explicit opt-in for local files;
        # `embedding` must match the model used when the index was built.
        self.vector_db = FAISS.load_local(
            "vs_68bf713eea2c81919ac08298a05d6704", embedding, allow_dangerous_deserialization=True
        )

    def __stream_answer__(self, instructions, input_messages):
        """Stream assistant tokens for the given instructions and chat history."""
        response_stream = self.client.chat(
            model=self.model_name,
            messages=input_messages + [{"role": "user", "content": instructions}],
            stream=True,
        )
        for chunk in response_stream:
            # Each streamed chunk carries a partial assistant message.
            if "message" in chunk and "content" in chunk["message"]:
                yield chunk["message"]["content"]
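
    # For reference, streaming chunks are assumed to follow Ollama's chat
    # format, e.g. {"message": {"role": "assistant", "content": "..."}, "done": false};
    # only the incremental "content" field is yielded above.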

    def generate_final_answer(self, context_prompt):
        """Generate the final answer using the retrieved context."""
        with self.st.chat_message("assistant"):
            # write_stream renders tokens as they arrive and returns the
            # concatenated text once the generator is exhausted.
            answer = self.st.write_stream(
                self.__stream_answer__(
                    context_prompt,
                    [{"role": m["role"], "content": m["content"]} for m in self.st.session_state.messages],
                )
            )
        self.st.session_state.messages.append({"role": "assistant", "content": answer})
        self.st.experimental_rerun()  # replaced by st.rerun() in newer Streamlit

    def __handle_context__(self, response_state: ResponseState) -> str:
        """Build the context prompt based on the response state."""
        context_prompt = ""
        if response_state.contextType in ("data", "both"):
            # Execute the model-generated analysis code against the dataframe.
            # Note: exec runs unsandboxed; the scope below is the only guard.
            local_scope = {"df": self.df, "np": np, "pd": pd, "plt": plt, "savefig": self.__safe_savefig__}
            exec(response_state.code, {}, local_scope)
            fig = plt.gcf()
            if fig.get_axes():  # a chart was generated
                with self.st.chat_message("assistant"):
                    self.st.pyplot(fig)
                plt.close(fig)
            # By convention, the generated code stores its output in `result`.
            context_prompt = "## CONTEXT DATAFRAME.\n"
            context_prompt += str(local_scope.get("result", ""))
        if response_state.contextType in ("pdf", "both"):
            context_prompt += "## CONTEXT PDF.\n"
            # Retrieve the five most similar chunks from the FAISS index.
            results = self.vector_db.similarity_search(response_state.retriverKey, k=5)
            for i, doc in enumerate(results, 1):
                context_prompt += f"### Document {i}\n{doc.page_content}\n"
        return context_prompt
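
    # Illustrative only (assumed column names): for contextType "data" the
    # refinery step might emit code such as
    #     result = df.groupby("sector")["revenue"].sum()
    #     result.plot(kind="bar")
    # `result` becomes the dataframe context and the open figure is rendered.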

    def process_prompt(self, prompt):
        """Main pipeline for processing a new user input."""
        self.st.session_state.messages.append({"role": "user", "content": prompt})
        with self.st.chat_message("user"):
            self.st.markdown(prompt)
        # Step 1: Run the refinery prompt to classify the request and, if
        # needed, produce analysis code or a retrieval key.
        response = self.client.chat(
            model=self.model_name,
            messages=[{"role": m["role"], "content": m["content"]} for m in self.st.session_state.messages]
            + [
                {
                    "role": "user",
                    "content": REFINERY_PROMPT.format(
                        response_format=ResponseState.model_json_schema(),
                        df_head=self.df.head().to_markdown(),
                        df_columns=self.df.columns.tolist(),
                        df_sample=self.df.sample(5).to_markdown(),
                    ),
                }
            ],
            # Pass the JSON schema (not the class) so the request body stays serializable.
            format=ResponseState.model_json_schema(),
            stream=False,
        )
response_state: ResponseState = ResponseState.model_validate_json(
response["message"]["content"]
)
# Step 2: Check if context is needed
if response_state.isNeedContext:
context_prompt = self.__handle_context__(response_state)
self.generate_final_answer(context_prompt)
else:
self.display_final_answer(response_state.response)
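

# Usage sketch, not part of the original file: wiring the agent into a
# Streamlit page. The embedding model name and chat loop are assumptions;
# `embedding` must match the model used to build the FAISS index.
#
#     import streamlit as st
#     from langchain_community.embeddings import HuggingFaceEmbeddings
#
#     embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#     agent = FinancialAgentOllama(st, embedding=embedding)
#     if prompt := st.chat_input("Ask about the financial data"):
#         agent.process_prompt(prompt)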