|
|
import streamlit as st |
|
|
import os |
|
|
import glob |
|
|
import torch |
|
|
import faiss |
|
|
import numpy as np |
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from sentence_transformers import CrossEncoder |
|
|
import pickle |
|
|
import chromadb |
|
|
from chromadb.utils import embedding_functions |
|
|
from huggingface_hub import login |
|
|
|
|
|
# --- Hugging Face authentication ---------------------------------------
# Log in with HF_TOKEN when it is present; otherwise fall back to
# anonymous access (private models will 403).
hf_token = os.getenv("HF_TOKEN")

if not hf_token:
    st.warning("HF_TOKEN not found. Using anonymous access (may lead to 403 for private models).")
else:
    try:
        login(token=hf_token)
    except Exception as e:
        st.error(f"Hugging Face login failed (check token validity): {e}")
    else:
        st.success("Hugging Face token successfully validated (or cached).")
    # Export the token under both env var names so downstream libraries
    # (transformers, huggingface_hub) pick it up regardless of login outcome.
    os.environ["HF_TOKEN"] = hf_token
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
|
|
|
|
|
# Working directory for every artifact this app produces.
BASE_DIR = "/tmp"
os.makedirs(BASE_DIR, exist_ok=True)

# Artifact paths: raw collected text, vector indexes, cached embeddings,
# and the embedding/index choices recorded at DB-creation time (re-read
# by the inquiry tab so queries use the identical pipeline).
collected_file = BASE_DIR + "/collected_data.txt"
vector_db_file = BASE_DIR + "/vector_db.faiss"
embedding_file = BASE_DIR + "/embeddings.npy"
chunks_file = BASE_DIR + "/chunks.pkl"
emb_choice_file = BASE_DIR + "/embedding_choice.txt"
index_choice_file = BASE_DIR + "/index_choice.txt"

# ChromaDB persists into its own subdirectory.
chroma_dir = BASE_DIR + "/chroma_db"
os.makedirs(chroma_dir, exist_ok=True)
|
|
|
|
|
def bert_encode(model, tokenizer, texts, batch_size=300, device="cpu"):
    """Embed a list of texts with a raw HF encoder model via mean pooling.

    Args:
        model: a ``transformers`` encoder (e.g. ``AutoModel``) whose output
            exposes ``last_hidden_state``.
        tokenizer: the matching tokenizer.
        texts: list of strings to embed.
        batch_size: number of texts per forward pass.
        device: torch device string ("cpu" or "cuda").

    Returns:
        ``np.ndarray`` of shape ``(len(texts), hidden_size)``, dtype float16.
    """
    model.to(device)
    model.eval()  # ensure dropout is off even if the caller left train mode on
    all_embeddings = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = tokenizer(
                batch_texts,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512,
            ).to(device)
            outputs = model(**inputs)
            hidden = outputs.last_hidden_state  # (batch, seq_len, hidden)
            # Mask-aware mean pooling. A plain ``mean(dim=1)`` averages over
            # [PAD] positions too, which dilutes the embeddings of short texts
            # whenever a batch mixes lengths; weight by the attention mask so
            # only real tokens contribute.
            mask = inputs["attention_mask"].unsqueeze(-1).type_as(hidden)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid div-by-zero
            embeddings = summed / counts
            # float16 halves the on-disk/in-memory footprint; consumers cast
            # back to float32 where required (e.g. FAISS).
            all_embeddings.append(embeddings.cpu().numpy().astype("float16"))
    return np.vstack(all_embeddings)
|
|
|
|
|
# ---- Page chrome -------------------------------------------------------
# App title plus a one-line summary of the pipeline stages.
st.title("Retriever Web Application")

st.header("Collect Data -> Chunking & Embedding -> Vector DB Creation -> Inquiry & Reranking")

# One tab per pipeline stage: ingest, index, and query/rerank.
tab1, tab2, tab3 = st.tabs(["Collect Data", "DB Formation", "Inquiry Vector DB"])
|
|
|
|
|
with tab1:
    st.header("Collect Data")

    # Free-form paste box; the content is persisted verbatim to disk so the
    # other tabs can read it back.
    pasted_text = st.text_area(
        "Paste your data file content here:",
        height=300,
        placeholder="Paste your text data here..."
    )

    pressed = st.button("Collect", key="collect_data_button_tab1")

    if not pressed:
        # Idle state: nothing clicked yet this rerun.
        st.write("Waiting for data input...")
    elif not pasted_text:
        st.warning("Please paste content into the text area before clicking Collect.")
    else:
        with open(collected_file, "w", encoding="utf-8") as f:
            f.write(pasted_text)
        st.success("Collected 1 file's content successfully!")
|
|
|
|
|
|
|
|
with tab2:
    st.header("Vector DB Formation")

    # Chunking and indexing knobs. NOTE: the widget ranges allow
    # overlap (0-500) >= chunk_size (50-1000); see the step clamp below.
    chunk_size = st.number_input("Chunk size:", 50, 1000, 200, step=50)
    overlap = st.number_input("Overlap size:", 0, 500, 50, step=10)
    embedding_choice = st.selectbox("Embedding Technique", ["SentencePiece", "TF-IDF", "BERT"])
    index_choice = st.selectbox("Vector DB", ["FAISS","ChromaDB"])
    embeddings = None

    if st.button("Create DB"):
        if not os.path.exists(collected_file):
            # Tab 1 was never run; report it instead of raising FileNotFoundError.
            st.error("No collected data found. Use the 'Collect Data' tab first.")
            st.stop()

        with open(collected_file, "r", encoding="utf-8") as f:
            text_data = f.read()

        # Fixed-size sliding-window chunking. When overlap >= chunk_size the
        # naive step (chunk_size - overlap) is zero or negative, which makes
        # range() raise; clamp the step to at least 1 so chunking always
        # advances and terminates.
        step = max(chunk_size - overlap, 1)
        chunks = [text_data[i:i + step + (chunk_size - step)] for i in range(0, len(text_data), step)]
        chunks = [text_data[i:i + chunk_size] for i in range(0, len(text_data), step)]

        # --- Embed the chunks with the selected technique ------------------
        if embedding_choice == "SentencePiece":
            model = SentenceTransformer("all-MiniLM-L6-v2", use_auth_token=hf_token)
            embeddings = model.encode(chunks, batch_size=300)
        elif embedding_choice == "TF-IDF":
            vectorizer = TfidfVectorizer()
            embeddings = vectorizer.fit_transform(chunks).toarray()
        elif embedding_choice == "BERT":
            model_name = "bert-base-uncased"
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
            model = AutoModel.from_pretrained(model_name, token=hf_token)
            embeddings = bert_encode(model, tokenizer, chunks)

        # --- Persist the vectors into the selected store --------------------
        if index_choice == "FAISS":
            dim = len(embeddings[0])
            index = faiss.IndexFlatL2(dim)
            # FAISS only accepts float32 input.
            index.add(np.array(embeddings).astype("float32"))
            faiss.write_index(index, vector_db_file)
            np.save(embedding_file, embeddings)
        else:
            client = chromadb.PersistentClient(path=chroma_dir)
            try:
                # Drop any stale collection so the DB reflects only this run.
                client.delete_collection("rag_collection")
            except Exception:
                # Collection may simply not exist yet; that is fine.
                pass
            collection = client.get_or_create_collection("rag_collection")
            collection.add(
                documents=chunks,
                # Plain float32 lists keep Chroma happy regardless of the
                # embedding dtype (the BERT path produces float16 arrays).
                embeddings=np.asarray(embeddings, dtype="float32").tolist(),
                ids=[str(i) for i in range(len(chunks))]
            )

        # Record the chunks and both choices so the inquiry tab can rebuild
        # the exact same embedding pipeline for incoming queries.
        with open(chunks_file, "wb") as f:
            pickle.dump(chunks, f)
        with open(emb_choice_file, "w") as f:
            f.write(embedding_choice)
        with open(index_choice_file, "w") as f:
            f.write(index_choice)

        st.write(f"Saved embeddings with shape: {embeddings.shape}")
|
|
|
|
|
|
|
|
with tab3:
    st.header("Inquiry Vector DB")

    user_query = st.text_area("User Query")
    expert_answer = st.text_area("Expert Answer")
    k = st.number_input("Number of retrieved data (k):", 1, 20, 5, step=1)

    if st.button("Search"):
        if not os.path.exists(chunks_file):
            # The DB was never built; report it instead of raising FileNotFoundError.
            st.error("No vector DB found. Build one in the 'DB Formation' tab first.")
            st.stop()

        # Restore the chunks plus the embedding/index choices recorded at
        # DB-creation time so the query travels the identical pipeline.
        with open(chunks_file, "rb") as f:
            chunks = pickle.load(f)
        with open(emb_choice_file, "r") as f:
            embedding_choice = f.read().strip()
        with open(index_choice_file, "r") as f:
            index_choice = f.read().strip()

        st.header(f"Using Embedding: {embedding_choice}, Index: {index_choice}")

        # --- Embed the query with the same technique as the DB --------------
        query_emb = None
        if embedding_choice == "SentencePiece":
            model = SentenceTransformer("all-MiniLM-L6-v2")
            query_emb = model.encode([user_query])
        elif embedding_choice == "TF-IDF":
            # Refitting on the stored chunks reproduces the original vocabulary,
            # so the query lands in the same vector space as the indexed docs.
            vectorizer = TfidfVectorizer()
            vectorizer.fit(chunks)
            query_emb = vectorizer.transform([user_query]).toarray()
        elif embedding_choice == "BERT":
            model_name = "bert-base-uncased"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name)
            query_emb = bert_encode(model, tokenizer, [user_query])

        # --- Retrieve top-k neighbours from the selected store ---------------
        if index_choice == "ChromaDB":
            st.write("Using ChromaDB with cosine similarity. In cosine similarity, a score closer to 1 means more similarity. " \
                     "Conversely, a score closer to 0 means less similarity." \
                     "Cosine similarity scores range from -1 to 1, where 1 indicates perfect similarity, 0 indicates no similarity, and -1 indicates " \
                     "perfect dissimilarity.")

            client = chromadb.PersistentClient(path=chroma_dir)
            collection = client.get_or_create_collection("rag_collection")
            results = collection.query(
                query_embeddings=query_emb.tolist(),
                n_results=k,
                include=["documents", "distances"]
            )

            retrieved_texts = results["documents"][0]
            retrieved_scores = results["distances"][0]

            st.subheader("Retrieved texts and scores:")
            for doc, score in zip(retrieved_texts, retrieved_scores):
                st.markdown(f"**Score:** {score:.4f}")
                st.write(doc)
                st.markdown("---")
        else:
            st.write("Using FAISS with L2 distance In L2 distance, a smaller score means more similarity." \
                     "L2 distance scores range from 0 to infinity, where 0 indicates perfect similarity (identical vectors), "
                     "and larger values indicate less similarity.")

            index = faiss.read_index(vector_db_file)
            # FAISS requires float32 queries; D = distances, I = chunk indices.
            D, I = index.search(query_emb.astype("float32"), k)

            retrieved_texts = [chunks[i] for i in I[0]]
            st.subheader("Retrieved texts and scores:")
            for doc, score in zip(retrieved_texts, D[0]):
                st.markdown(f"**Score:** {score:.4f}")
                st.write(doc)
                st.markdown("---")

        # --- Cross-encoder reranking -----------------------------------------
        st.write("Reranking using Cross-ReRank (higher score means more relevance, and lower score means less relevance). " \
                 "It is relative ranking (higher score = more relevant), not the absolute magnitude.")
        reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", token=hf_token)
        # Score all (query, doc) pairs in ONE batched forward pass instead of
        # one model invocation per document — same scores, k-fold fewer calls.
        scores = reranker.predict([(user_query, doc) for doc in retrieved_texts])
        st.subheader("Reranked scores:")
        for doc, score in zip(retrieved_texts, scores):
            st.markdown(f"**Rerank Score:** {score:.4f}")
            st.write(doc)
            st.markdown("---")

        if expert_answer.strip():
            # Batched for the same reason as above.
            relevance_scores = reranker.predict([(expert_answer, doc) for doc in retrieved_texts])
            st.subheader("Relevance to Expert Answer:")
            for doc, score in zip(retrieved_texts, relevance_scores):
                st.markdown(f"**Relevance Score:** {score:.4f}")
                st.write(doc)
                st.markdown("---")