Spaces:
Sleeping
Sleeping
File size: 1,793 Bytes
b02256f 71a3adc ad707c0 2b20076 6e84237 403aa45 2b20076 71a3adc 351319c 71a3adc 2b20076 71a3adc ad707c0 71a3adc ad707c0 71a3adc ad707c0 71a3adc 351319c 403aa45 71a3adc ad707c0 71a3adc ad707c0 71a3adc ad707c0 351319c 403aa45 351319c 71a3adc ad707c0 71a3adc 403aa45 ad707c0 79b62bd 403aa45 ad707c0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 | import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import pickle
# Page header: app title plus a one-line usage hint shown above the chat input.
st.title("Fin$mart Chatbot")
st.markdown("Ask financial questions and get answers based on expert knowledge.")
# Load models
@st.cache_resource
def load_models():
    """Load and cache the generator, its tokenizer, and the sentence embedder.

    Wrapped in st.cache_resource so the heavyweight model objects are
    constructed once per Streamlit server process rather than on every
    script rerun.

    Returns:
        (tokenizer, model, embedder) — the Flan-T5 tokenizer and
        seq2seq model, and the MiniLM sentence-embedding model.
    """
    t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    minilm_embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return t5_tokenizer, t5_model, minilm_embedder

tokenizer, model, embedder = load_models()
# Load vector store from existing pickle file
@st.cache_resource
def load_vector_store():
    """Load the cached vector index and its source texts from vectorstore.pkl.

    NOTE(review): pickle.load can execute arbitrary code from the file —
    only load vectorstore.pkl if it comes from a trusted build step.

    Returns:
        (index, texts) — the similarity-search index and the list of
        text chunks it was built from.
    """
    with open("vectorstore.pkl", "rb") as pkl_file:
        # The pickle holds (index, texts, embeddings); the stored
        # embeddings are not needed at query time, so discard them.
        stored_index, stored_texts, _ = pickle.load(pkl_file)
    return stored_index, stored_texts

index, texts = load_vector_store()
# Chat interface
prompt = st.chat_input("Ask something about finance...")
if prompt:
    # Embed the query and retrieve the top-k most similar chunks.
    q_embed = embedder.encode([prompt])
    # Bound k by the corpus size: FAISS pads missing neighbours with -1,
    # and texts[-1] would silently return the wrong chunk. Pass k
    # positionally — SWIG-wrapped search methods may reject keyword args.
    k = min(3, len(texts))
    _, I = index.search(q_embed, k)
    hits = [i for i in I[0] if 0 <= i < len(texts)]
    context = " ".join(texts[i] for i in hits)

    # Build input for Flan-T5: instruction + retrieved context + question.
    input_text = (
        f"You are a helpful financial assistant. Use the information provided below to answer the user's question.\n\n"
        f"Context: {context}\n\n"
        f"Question: {prompt}\n\n"
        f"Answer:"
    )
    # Truncate the prompt to the encoder's 512-token budget so oversized
    # contexts don't raise at encode time.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    # max_new_tokens bounds only the generated answer; it replaces the
    # ambiguous legacy max_length parameter.
    outputs = model.generate(**inputs, max_new_tokens=150)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Display response
    st.markdown(f"**Answer:** {answer}")

    # Show retrieved context
    with st.expander("Context Used"):
        for i in hits:
            st.write(texts[i])