File size: 1,793 Bytes
b02256f
71a3adc
 
ad707c0
2b20076
6e84237
403aa45
2b20076
71a3adc
 
 
351319c
 
71a3adc
 
2b20076
71a3adc
 
ad707c0
71a3adc
ad707c0
 
 
 
71a3adc
ad707c0
71a3adc
351319c
403aa45
71a3adc
 
ad707c0
71a3adc
ad707c0
71a3adc
 
ad707c0
351319c
 
 
 
 
 
 
403aa45
351319c
71a3adc
 
ad707c0
71a3adc
403aa45
ad707c0
79b62bd
403aa45
ad707c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import pickle

# Page header: app title and a one-line description of what the bot does.
st.title("Fin$mart Chatbot")
st.markdown("Ask financial questions and get answers based on expert knowledge.")

# Model loading (cached once per Streamlit session).
@st.cache_resource
def load_models():
    """Load and cache the generator and the query embedder.

    Returns:
        A ``(tokenizer, model, embedder)`` tuple: the Flan-T5 tokenizer,
        the Flan-T5 seq2seq model, and a MiniLM sentence embedder.
        ``st.cache_resource`` ensures the heavy downloads happen only once.
    """
    generator_name = "google/flan-t5-base"
    flan_tokenizer = AutoTokenizer.from_pretrained(generator_name)
    flan_model = AutoModelForSeq2SeqLM.from_pretrained(generator_name)
    sentence_embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return flan_tokenizer, flan_model, sentence_embedder

tokenizer, model, embedder = load_models()

# Vector-store loading (cached once per Streamlit session).
@st.cache_resource
def load_vector_store():
    """Deserialize the prebuilt vector store from ``vectorstore.pkl``.

    The pickle holds a 3-tuple ``(index, texts, embeddings)``; the stored
    embeddings are not needed at query time and are discarded.

    NOTE(review): ``pickle.load`` executes arbitrary code from the file —
    only ship a ``vectorstore.pkl`` produced by a trusted build step.

    Returns:
        ``(index, texts)`` — the similarity index and its source chunks.
    """
    with open("vectorstore.pkl", "rb") as f:
        payload = pickle.load(f)
    search_index, chunk_texts, _embeddings = payload
    return search_index, chunk_texts

index, texts = load_vector_store()

# Chat interface: embed the question, retrieve context, generate an answer.
prompt = st.chat_input("Ask something about finance...")

if prompt:
    # Embed the query and retrieve the top-3 most similar chunks.
    q_embed = embedder.encode([prompt])
    _, I = index.search(q_embed, k=3)
    # FAISS pads results with -1 when the index holds fewer than k vectors;
    # drop those so we never silently read texts[-1] (the last chunk).
    hits = [i for i in I[0] if i >= 0]
    context = " ".join(texts[i] for i in hits)

    # Build the instruction + retrieved-context prompt for Flan-T5.
    input_text = (
        "You are a helpful financial assistant. Use the information provided below to answer the user's question.\n\n"
        f"Context: {context}\n\n"
        f"Question: {prompt}\n\n"
        "Answer:"
    )

    # Truncate to the encoder's 512-token window, then generate the answer.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(**inputs, max_length=150)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Display response
    st.markdown(f"**Answer:** {answer}")

    # Show the retrieved chunks for transparency.
    with st.expander("Context Used"):
        for i in hits:
            st.write(texts[i])