# csds-project / search.py
import re
import torch
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from embeddings.embedder import initialize_embedding_model, initialize_chroma
from config import EMBEDDINGS_DIR, EMBEDDING_MODEL_NAME
from embeddings.latex_to_unicode import LATEX_TO_UNICODE
def decode_latex(text: str) -> str:
    """Convert LaTeX markup in *text* to plain Unicode text.

    Known LaTeX tokens are mapped through LATEX_TO_UNICODE; any remaining
    ``\\command{...}`` constructs are stripped, and stray braces removed.
    """
    for latex_token, unicode_char in LATEX_TO_UNICODE.items():
        text = text.replace(latex_token, unicode_char)
    # Drop leftover LaTeX commands (with an optional braced argument)
    # that had no entry in the mapping table.
    text = re.sub(r"\\[a-zA-Z]+(\{.*?\})?", "", text)
    cleaned = text.replace("{", "").replace("}", "")
    return cleaned.strip()
# Sentence-level embedding model used to rank individual sentences within a
# single document against the query (distinct from the document-level vectors
# stored in Chroma, though both use EMBEDDING_MODEL_NAME).
sentence_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
def best_sentence_by_embedding(content: str, query: str):
    """Return the sentence of *content* most similar to *query*.

    Splits *content* on sentence-ending punctuation, embeds all sentences
    plus the query in a single batch, and picks the sentence with the
    highest cosine similarity to the query.

    Args:
        content: document text to search within.
        query: user query string.

    Returns:
        (best_sentence, score): the most similar sentence and its cosine
        similarity as a float. Returns ("", 0.0) when *content* yields no
        sentences — the original crashed on an argmax over an empty tensor.
    """
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", content) if s.strip()]
    if not sentences:
        # Empty or whitespace-only content: nothing to rank.
        return "", 0.0
    # Encode sentences and the query together so they share one forward pass;
    # the query embedding is the last row.
    embeddings = sentence_model.encode(sentences + [query], convert_to_tensor=True)
    cosine_scores = util.cos_sim(embeddings[-1], embeddings[:-1])[0]
    best_idx = int(torch.argmax(cosine_scores))
    return sentences[best_idx], cosine_scores[best_idx].item()
def semantic_search(vectordb, query, k=5):
    """Run a top-k similarity search, preferring the scored API.

    Tries ``similarity_search_with_score`` first; if the backend raises,
    falls back to the plain ``similarity_search`` and pairs every hit
    with ``None`` so callers always receive (doc, score) tuples.
    """
    try:
        return vectordb.similarity_search_with_score(query, k=k)
    except Exception:
        # Best-effort fallback for stores without score support.
        hits = vectordb.similarity_search(query, k=k)
        return [(hit, None) for hit in hits]
def normalize_score(distance):
    """Map a distance (lower = closer) onto a (0, 1] similarity score.

    A missing distance (``None``) is treated as zero similarity.
    """
    return 0.0 if distance is None else 1 / (1 + distance)
def get_user_input():
    """Render the query text box and result-count slider; return (query, k)."""
    search_text = st.text_input("Enter search query:")
    result_count = st.slider("Number of results", min_value=1, max_value=10, value=5)
    return search_text, result_count
def truncate_sentence(text: str, max_len: int = 1000) -> str:
    """Cut *text* to *max_len* characters, appending an ellipsis when cut."""
    if len(text) > max_len:
        return text[:max_len] + "..."
    return text
def process_results(results_with_scores, query):
    """Deduplicate, enrich, and rank raw search hits.

    For each (doc, distance) pair: skips documents whose metadata id was
    already seen, decodes LaTeX in the text, extracts the title and
    abstract, finds the sentence most relevant to *query*, and blends
    sentence-level and document-level relevance into a final score.

    Args:
        results_with_scores: list of (doc, distance) pairs; distance may be
            None when the backend returned unscored results.
        query: the user's search string.

    Returns:
        List of result dicts sorted by ``final_score``, descending.
    """
    ranked_results = []
    seen_ids = set()
    for doc, doc_score in results_with_scores:
        metadata = doc.metadata or {}
        doc_id = metadata.get("id", "N/A")
        if doc_id in seen_ids:
            continue
        seen_ids.add(doc_id)
        categories = metadata.get("categories", "N/A")
        year = metadata.get("year", "N/A")
        raw_content = decode_latex(doc.page_content)
        title = raw_content.split(". ", 1)[0].replace("Title: ", "").strip()
        # Fall back to the full text when the "Abstract:" marker is absent;
        # the original indexed [1] and raised IndexError on such documents.
        _, marker, abstract = raw_content.partition("Abstract:")
        content = abstract.strip() if marker else raw_content
        best_sentence, local_relevance = best_sentence_by_embedding(content, query)
        # A missing distance (unscored fallback) counts as zero document
        # relevance; the original computed `1 - None` and raised TypeError.
        doc_relevance = (1 - doc_score) if doc_score is not None else 0.0
        final_score = 0.6 * local_relevance + 0.4 * doc_relevance
        ranked_results.append(
            {
                "doc": doc,
                "doc_id": doc_id,
                "categories": categories,
                "year": year,
                "title": title,
                "content": content,
                "best_sentence": best_sentence,
                "local_relevance": local_relevance,
                "doc_score": doc_score,
                "final_score": final_score,
            }
        )
    return sorted(ranked_results, key=lambda r: r["final_score"], reverse=True)
def display_results(ranked_results):
    """Render ranked search results in Streamlit.

    Shows, per result: metadata line (id, categories, year, relevance
    scores), the title, and a truncated excerpt with the best-matching
    sentence highlighted in bold markdown.
    """
    st.success(f"Top {len(ranked_results)} results found:")
    for i, r in enumerate(ranked_results, 1):
        content = r["content"]
        # Only bold a non-empty sentence; replacing "" would inject a
        # stray "****" at the start of the excerpt.
        if r["best_sentence"]:
            highlighted_content = content.replace(
                r["best_sentence"], f"**{r['best_sentence']}**", 1
            )
        else:
            highlighted_content = content
        # Distinguish "no score available" (None -> 0.00) from a genuine
        # zero distance; the old truthiness test reported None as 1.00.
        doc_relevance = (1 - r["doc_score"]) if r["doc_score"] is not None else 0.0
        st.markdown(f"**RESULT {i}:**")
        st.markdown(
            f"Document ID: {r['doc_id']} | Categories: {r['categories']} | Year: {r['year']} | "
            f"Doc Relevance: {doc_relevance:.2f} | "
            f"Best Sentence Relevance: {r['local_relevance']:.2f}"
        )
        st.markdown(f"Title: {r['title']}")
        st.markdown(f"Most Relevant Excerpt: {truncate_sentence(highlighted_content)}")
        st.markdown("---")
def run_search(embedding_model=None, vectordb=None):
    """Streamlit entry point for the semantic search page.

    Lazily initializes the embedding model and Chroma store when not
    supplied, reads the user's query and desired result count, retrieves
    2*k candidates (headroom for duplicates removed during ranking), and
    displays the top k ranked results.

    Args:
        embedding_model: optional pre-initialized embedding model.
        vectordb: optional pre-initialized Chroma vector store.
    """
    st.header("🔎 Semantic Search")
    st.subheader("Search for semantically similar documents")
    if embedding_model is None:
        embedding_model = initialize_embedding_model()
    if vectordb is None:
        vectordb = initialize_chroma(embedding_model, EMBEDDINGS_DIR)
    if not vectordb:
        st.warning("No ChromaDB found. Run embeddings generation first.")
        return
    query, k = get_user_input()
    if not query:
        st.info("Type a query above to start searching.")
        return
    # Over-fetch so deduplication in process_results still leaves ~k hits.
    results_with_scores = semantic_search(vectordb, query, k=k * 2)
    if not results_with_scores:
        st.warning("No results found.")
        return
    ranked_results = process_results(results_with_scores, query)
    # Honor the user's requested result count: previously all 2*k surviving
    # candidates were displayed regardless of the slider value.
    display_results(ranked_results[:k])