# mydemoapp/src/streamlit_app.py
# Last commit: 40471fa "edit ui" by Soha85
import streamlit as st
import os
import glob
import torch
import faiss
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import CrossEncoder
import pickle
import chromadb
from chromadb.utils import embedding_functions
from huggingface_hub import login
# Hugging Face Hub authentication.
# The token is read from the HF_TOKEN environment variable; without it all
# model downloads are anonymous (private models may then answer with 403).
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.warning("HF_TOKEN not found. Using anonymous access (may lead to 403 for private models).")
else:
    try:
        login(token=hf_token)
        st.success("Hugging Face token successfully validated (or cached).")
    except Exception as e:
        # Report the failure in the UI but keep the app running.
        st.error(f"Hugging Face login failed (check token validity): {e}")
    # Re-export under both names so downstream libraries reading either one
    # see the token, even when the explicit login above failed.
    os.environ.update(HF_TOKEN=hf_token, HUGGINGFACE_HUB_TOKEN=hf_token)
# Every artefact the app produces lives under one scratch directory so the
# three tabs can hand data to each other through the filesystem.
BASE_DIR = "/tmp"
os.makedirs(BASE_DIR, exist_ok=True)


def _in_base_dir(name):
    """Return the path of *name* inside BASE_DIR, joined with a forward slash."""
    return f"{BASE_DIR}/{name}"


# Global variables
collected_file = _in_base_dir("collected_data.txt")     # raw text saved by "Collect Data"
vector_db_file = _in_base_dir("vector_db.faiss")        # FAISS index written by "DB Formation"
embedding_file = _in_base_dir("embeddings.npy")         # chunk embeddings (saved on the FAISS path)
chunks_file = _in_base_dir("chunks.pkl")                # pickled list of text chunks
emb_choice_file = _in_base_dir("embedding_choice.txt")  # embedding technique used to build the DB
index_choice_file = _in_base_dir("index_choice.txt")    # vector DB backend used to build the DB
chroma_dir = _in_base_dir("chroma_db")                  # ChromaDB persistent storage directory
os.makedirs(chroma_dir, exist_ok=True)
def bert_encode(model, tokenizer, texts, batch_size=300, device="cpu"):
    """Encode *texts* into fixed-size vectors by mean-pooling a transformer's last hidden state.

    Args:
        model: A Hugging Face model whose output exposes ``last_hidden_state``
            (e.g. the result of ``AutoModel.from_pretrained``).
        tokenizer: The matching tokenizer; its output is moved to *device* and
            passed to *model* as keyword arguments.
        texts: List of strings to encode.
        batch_size: Number of texts tokenized and run through the model at once.
        device: Torch device string the model and inputs are moved to.

    Returns:
        ``np.ndarray`` of dtype float16 with one row per input text. An empty
        *texts* list yields an array of shape ``(0, 0)``.
    """
    model.to(device)
    all_embeddings = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            mask = inputs.get("attention_mask")
            if mask is None:
                pooled = hidden.mean(dim=1)
            else:
                # Average only over real tokens. Including padding positions would
                # pull short texts toward the pad representation and make a text's
                # vector depend on which other texts share its batch (a query
                # encoded alone would not match the same text encoded in a batch).
                weights = mask.unsqueeze(-1).to(hidden.dtype)
                summed = (hidden * weights).sum(dim=1)
                counts = weights.sum(dim=1).clamp(min=1e-9)
                pooled = summed / counts
            all_embeddings.append(pooled.cpu().numpy().astype("float16"))  # save memory
    if not all_embeddings:
        # np.vstack([]) raises; return an explicit empty matrix instead.
        return np.empty((0, 0), dtype="float16")
    return np.vstack(all_embeddings)
st.title("Retriever Web Application")
st.header("Collect Data -> Chunking & Embedding -> Vector DB Creation -> Inquiry & Reranking")

# One tab per pipeline stage; the stages exchange data through files under BASE_DIR.
tab1, tab2, tab3 = st.tabs(["Collect Data", "DB Formation", "Inquiry Vector DB"])

# Tab 1: Collect Data
with tab1:
    st.header("Collect Data")
    text_input = st.text_area(
        "Paste your data file content here:",
        height=300,
        placeholder="Paste your text data here...",
    )
    collected_file_path = collected_file
    # An explicit key gives this button a stable, unique widget identity.
    collect_button_pressed = st.button("Collect", key="collect_data_button_tab1")

    if not collect_button_pressed:
        st.write("Waiting for data input...")
    elif not text_input:
        st.warning("Please paste content into the text area before clicking Collect.")
    else:
        all_text = text_input
        # Persist the pasted text so the DB Formation tab can read it on a later rerun.
        with open(collected_file_path, "w", encoding="utf-8") as out_file:
            out_file.write(all_text)
        st.success("Collected 1 file's content successfully!")
# Tab 2: DB Formation


def _chunk_text(text, chunk_size, overlap):
    """Split *text* into character windows of *chunk_size* that overlap by *overlap* characters.

    Raises:
        ValueError: if ``overlap >= chunk_size`` (the window would never advance).
    """
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("Overlap size must be smaller than chunk size.")
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]


def _embed_chunks(chunks, embedding_choice):
    """Return a 2-D numpy array with one embedding row per chunk for the chosen technique.

    Raises:
        ValueError: for an unknown *embedding_choice*.
    """
    if embedding_choice == "SentencePiece":
        # `token=` matches the other Hub calls in this file; `use_auth_token` is deprecated.
        model = SentenceTransformer("all-MiniLM-L6-v2", token=hf_token)
        return model.encode(chunks, batch_size=300)
    if embedding_choice == "TF-IDF":
        return TfidfVectorizer().fit_transform(chunks).toarray()
    if embedding_choice == "BERT":
        model_name = "bert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        model = AutoModel.from_pretrained(model_name, token=hf_token)
        return bert_encode(model, tokenizer, chunks)
    raise ValueError(f"Unknown embedding technique: {embedding_choice}")


def _save_index(chunks, embeddings, index_choice):
    """Build the selected vector store ("FAISS" or "ChromaDB") from *embeddings* and persist it."""
    if index_choice == "FAISS":
        vectors = np.asarray(embeddings, dtype="float32")
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        faiss.write_index(index, vector_db_file)
        np.save(embedding_file, embeddings)
        return

    # ChromaDB
    client = chromadb.PersistentClient(path=chroma_dir)
    try:
        # Start from a clean collection so chunks/ids from a previous build don't linger.
        client.delete_collection("rag_collection")
    except Exception:
        # Expected on the first build (no collection yet); the exception type
        # differs between chromadb versions, so this stays best-effort.
        pass
    # Chroma's default space is squared L2; request cosine, which is the metric
    # this app intends to use for ChromaDB.
    collection = client.get_or_create_collection(
        "rag_collection", metadata={"hnsw:space": "cosine"}
    )
    collection.add(
        documents=chunks,
        # The BERT path produces float16 arrays; plain float32 lists are the
        # most portable embedding input across chromadb versions.
        embeddings=np.asarray(embeddings, dtype="float32").tolist(),
        ids=[str(i) for i in range(len(chunks))],
    )


def _create_db(chunk_size, overlap, embedding_choice, index_choice):
    """Handle a "Create DB" click: chunk, embed, index and persist the collected text.

    Returns the embedding matrix, or None when nothing was built (a warning or
    error is shown in the UI instead of crashing the page).
    """
    if not os.path.exists(collected_file):
        st.warning("No collected data found. Paste your text in the 'Collect Data' tab and click Collect first.")
        return None
    with open(collected_file, "r", encoding="utf-8") as f:
        text_data = f.read()
    try:
        chunks = _chunk_text(text_data, chunk_size, overlap)
    except ValueError as err:
        st.error(str(err))
        return None
    if not chunks:
        st.warning("The collected data is empty; nothing to index.")
        return None

    embeddings = _embed_chunks(chunks, embedding_choice)
    _save_index(chunks, embeddings, index_choice)

    # Record what was built so the Inquiry tab can embed queries the same way.
    # NOTE(review): pickle is only safe while this file is written by the app itself.
    with open(chunks_file, "wb") as f:
        pickle.dump(chunks, f)
    with open(emb_choice_file, "w") as f:
        f.write(embedding_choice)
    with open(index_choice_file, "w") as f:
        f.write(index_choice)
    st.write(f"Saved embeddings with shape: {embeddings.shape}")
    return embeddings


with tab2:
    st.header("Vector DB Formation")
    chunk_size = st.number_input("Chunk size:", 50, 1000, 200, step=50)
    overlap = st.number_input("Overlap size:", 0, 500, 50, step=10)
    embedding_choice = st.selectbox("Embedding Technique", ["SentencePiece", "TF-IDF", "BERT"])
    index_choice = st.selectbox("Vector DB", ["FAISS", "ChromaDB"])
    embeddings = None
    if st.button("Create DB"):
        embeddings = _create_db(chunk_size, overlap, embedding_choice, index_choice)
# Tab 3: Inquiry Vector DB


def _embed_query(user_query, embedding_choice, chunks):
    """Embed *user_query* with the technique that built the DB; returns a (1, dim) array.

    Raises:
        ValueError: for an unknown *embedding_choice*.
    """
    if embedding_choice == "SentencePiece":
        model = SentenceTransformer("all-MiniLM-L6-v2", token=hf_token)
        return model.encode([user_query])
    if embedding_choice == "TF-IDF":
        # Refit on the stored chunks: fitting a default TfidfVectorizer on the
        # same chunks reproduces the vocabulary (and dimension) of the DB.
        vectorizer = TfidfVectorizer()
        vectorizer.fit(chunks)
        return vectorizer.transform([user_query]).toarray()
    if embedding_choice == "BERT":
        model_name = "bert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        model = AutoModel.from_pretrained(model_name, token=hf_token)
        return bert_encode(model, tokenizer, [user_query])
    raise ValueError(f"Unknown embedding technique: {embedding_choice}")


def _search_chroma(query_emb, top_k):
    """Query the persisted ChromaDB collection; return (texts, distances, distance_space)."""
    client = chromadb.PersistentClient(path=chroma_dir)
    collection = client.get_or_create_collection("rag_collection")
    # Chroma reports *distances* in the collection's space (default "l2").
    space = (collection.metadata or {}).get("hnsw:space", "l2")
    results = collection.query(
        query_embeddings=np.asarray(query_emb, dtype="float32").tolist(),
        n_results=top_k,
        include=["documents", "distances"],
    )
    return results["documents"][0], results["distances"][0], space


def _search_faiss(query_emb, chunks, top_k):
    """Search the persisted FAISS index; return (texts, squared-L2 distances)."""
    index = faiss.read_index(vector_db_file)
    distances, ids = index.search(np.asarray(query_emb, dtype="float32"), top_k)
    # FAISS pads missing results with id -1; without this filter chunks[-1]
    # would be displayed as a (bogus) match.
    hits = [
        (chunks[i], float(d))
        for d, i in zip(distances[0], ids[0])
        if 0 <= i < len(chunks)
    ]
    return [text for text, _ in hits], [dist for _, dist in hits]


def _show_hits(texts, scores, label="Score"):
    """Render each text under a bold '<label>: <score>' line, separated by rules."""
    for doc, score in zip(texts, scores):
        st.markdown(f"**{label}:** {score:.4f}")
        st.write(doc)
        st.markdown("---")


def _run_search(user_query, expert_answer, k):
    """Handle a "Search" click: retrieve the top-k chunks, then rerank them with a cross-encoder."""
    if not user_query.strip():
        st.warning("Please enter a query before searching.")
        return
    if not all(os.path.exists(p) for p in (chunks_file, emb_choice_file, index_choice_file)):
        st.warning("No vector DB found. Build one in the 'DB Formation' tab first.")
        return

    # Load chunks and the embedding/index choices recorded by the DB Formation tab.
    # NOTE(review): pickle.load runs arbitrary code from a crafted file; this file is
    # written by the app but lives in a shared scratch directory.
    with open(chunks_file, "rb") as f:
        chunks = pickle.load(f)
    with open(emb_choice_file, "r") as f:
        embedding_choice = f.read().strip()
    with open(index_choice_file, "r") as f:
        index_choice = f.read().strip()
    st.header(f"Using Embedding: {embedding_choice}, Index: {index_choice}")
    if not chunks:
        st.warning("The vector DB contains no chunks.")
        return

    # Never ask for more neighbours than exist.
    top_k = min(int(k), len(chunks))
    query_emb = _embed_query(user_query, embedding_choice, chunks)

    if index_choice == "ChromaDB":
        retrieved_texts, retrieved_scores, space = _search_chroma(query_emb, top_k)
        if space == "cosine":
            st.write("Using ChromaDB with cosine distance (cosine distance = 1 - cosine similarity). "
                     "A smaller score means more similarity: scores range from 0 (same direction, "
                     "most similar) to 2 (opposite direction, least similar).")
        else:
            st.write(f"Using ChromaDB with '{space}' distance. A smaller score means more similarity.")
    else:  # FAISS
        st.write("Using FAISS with squared L2 distance. A smaller score means more similarity. "
                 "Scores range from 0 to infinity, where 0 indicates perfect similarity "
                 "(identical vectors) and larger values indicate less similarity.")
        retrieved_texts, retrieved_scores = _search_faiss(query_emb, chunks, top_k)

    st.subheader("Retrieved texts and scores:")
    _show_hits(retrieved_texts, retrieved_scores, "Score")
    if not retrieved_texts:
        st.warning("No matching chunks were retrieved.")
        return

    # Reranking
    st.write("Reranking using Cross-ReRank (higher score means more relevance, and lower score means less relevance). " \
             "It is relative ranking (higher score = more relevant), not the absolute magnitude.")
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", token=hf_token)
    # One batched predict call instead of one model call per document.
    rerank_scores = reranker.predict([(user_query, doc) for doc in retrieved_texts])
    # Actually rerank: order documents by cross-encoder score, most relevant first.
    reranked = sorted(zip(retrieved_texts, rerank_scores), key=lambda pair: pair[1], reverse=True)
    st.subheader("Reranked scores:")
    _show_hits([doc for doc, _ in reranked], [score for _, score in reranked], "Rerank Score")

    # Measure relevance of each retrieved chunk to the expert answer, if one was provided.
    if expert_answer.strip():
        relevance_scores = reranker.predict([(expert_answer, doc) for doc in retrieved_texts])
        st.subheader("Relevance to Expert Answer:")
        _show_hits(retrieved_texts, relevance_scores, "Relevance Score")


with tab3:
    st.header("Inquiry Vector DB")
    user_query = st.text_area("User Query")
    expert_answer = st.text_area("Expert Answer")
    k = st.number_input("Number of retrieved data (k):", 1, 20, 5, step=1)
    if st.button("Search"):
        _run_search(user_query, expert_answer, k)