import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import time
import sys
import os
# Local imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.chunking import smart_chunk_text
from utils.retriever import HybridRetriever
from utils.generator import generate_answer
from utils.evaluation import evaluate_response
from utils.guardrails import validate_input, validate_output
from utils.nltk_bootstrap import ensure_punkt
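# Ensure the NLTK "punkt" tokenizer data is present (assumed to download it on first run)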
ensure_punkt()
# ---------------------------
# Streamlit Page Config
# ---------------------------
st.set_page_config(page_title="Allstate Financial QA")
st.title("📊 Allstate Financial QA System")
# ---------------------------
# Cached Loaders
# ---------------------------
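# st.cache_resource keeps these heavyweight objects (embedder, index, pipeline)
# alive across Streamlit reruns, so they are only built once per process.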
@st.cache_resource
def load_retriever():
    # Load every processed corpus file as raw text
    texts = []
    for file in os.listdir("data/processed"):
        if file.endswith(".txt") or file.endswith(".json"):
            with open(os.path.join("data/processed", file), "r", encoding="utf-8") as f:
                texts.append(f.read())
    chunks = smart_chunk_text(texts, chunk_size=100)
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return HybridRetriever(chunks, embedder)
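
# HybridRetriever is a local util; the name suggests it combines lexical and
# dense (embedding) retrieval over the chunks.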
@st.cache_resource
def load_finetuned_model():
    tokenizer = AutoTokenizer.from_pretrained("jayyd/financial-qa-model")
    model = AutoModelForCausalLM.from_pretrained("jayyd/financial-qa-model")
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe
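
# Note: transformers' pipeline() runs on CPU by default; pass device=0 here
# (or device_map="auto" when loading the model) to use a GPU if available.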
# ---------------------------
# UI Inputs
# ---------------------------
query = st.text_input("Ask a financial question")
method = st.radio("Choose method:", ["RAG", "Fine-Tuned"])
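# st.radio preselects the first option, so "RAG" is the default method.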
# ---------------------------
# Main App Logic
# ---------------------------
if query:
    # Validate input
    is_valid, message = validate_input(query)
    if not is_valid:
        st.error(message)
        st.stop()
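    # st.stop() halts this script run, so invalid queries never reach generation.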
    start_time = time.time()

    if method == "RAG":
        with st.spinner("Retrieving and generating answer..."):
            retriever = load_retriever()
            chunks = retriever.retrieve(query)

            # Extractive only: the answer is drawn from the retrieved chunks
            answer, supporting_context = generate_answer(query, chunks)

            # Evaluate
            metrics = evaluate_response(query, answer, chunks)
            confidence = metrics.get("confidence", 0.0)
            is_valid, message = validate_output(answer, confidence)

            # Show answer
            st.subheader("Answer")
            st.write(answer)
            if not is_valid:
                st.warning(message)

            # Supporting context (expandable)
            with st.expander("Supporting Context"):
                st.write(supporting_context)

            # Sidebar metrics
            response_time = time.time() - start_time
            st.sidebar.markdown("### Response Metrics")
            st.sidebar.markdown(f"Response Time: {response_time:.2f}s")
            st.sidebar.markdown(f"Confidence Score: {confidence:.2f}")
            st.sidebar.markdown(f"Number of Retrieved Chunks: {metrics.get('num_chunks', 0)}")
            st.sidebar.markdown(f"Chunk Relevance Score: {metrics.get('chunk_relevance', 0):.2f}")
    else:
        # Fine-tuned path: generate directly from the causal LM, no retrieval
        with st.spinner("Generating answer from fine-tuned model..."):
            pipe = load_finetuned_model()
            prompt = f"Q: {query}\nA:"
            raw_output = pipe(prompt, max_new_tokens=100)[0]["generated_text"]

            # Clean the output: keep only the text after the prompt's first "A:"
            output = raw_output.split("A:", 1)[-1].strip()

            # Evaluate
            metrics = evaluate_response(query, output)
            confidence = metrics.get("confidence", 0.0)
            is_valid, message = validate_output(output, confidence)

            # Show answer
            st.subheader("Answer")
            st.write(output)
            if not is_valid:
                st.warning(message)

            # Sidebar metrics
            response_time = time.time() - start_time
            st.sidebar.markdown("### Response Metrics")
            st.sidebar.markdown(f"Response Time: {response_time:.2f}s")
            st.sidebar.markdown(f"Confidence Score: {confidence:.2f}")
            st.sidebar.markdown(f"Answer Length: {metrics.get('answer_length', 0)} words")