# demo-embeddings / src/streamlit_app.py
# Origin: Hugging Face Space by Geraldine — "Update src/streamlit_app.py" (commit 6294992, verified).
import streamlit as st
import torch
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
import plotly.express as px
from sentence_transformers import SentenceTransformer
@st.cache_resource(show_spinner=True)
def load_model():
    """Load the multilingual-e5-small sentence encoder once per server process.

    ``st.cache_resource`` keeps a single shared instance across reruns and
    sessions, so the (slow) model download/initialization happens only once.
    """
    model = SentenceTransformer("intfloat/multilingual-e5-small")
    return model
# Instantiate the cached embedding model at import time (shared across reruns).
emb_model = load_model()

# Default demo corpus: research-topic titles, mixed English/French.
EXAMPLE_SENTENCES = [
    "LiDAR-based domain generalization and unknown 3D object detection",
    "Information balance and density in augmented reality assistance",
    "Learning Semantics and Geometry for Scene Understanding",
    "Integrating Expert Knowledge with Deep Reinforcement Learning Methods for Autonomous Driving",
    "Deep Reinforcement Learning and Learning from Demonstrations for Robot Manipulators",
    "Smart vehicule trajectory prediction in various autonomous driving scenarios",
    "Exploring LiDAR Odometries through Classical, Deep and Inertial perspectives",
    "Généralisation de domaine pour la segmentation sémantique de données LiDAR pour le véhicule autonome",
    "Localisation en intérieur par SLAM magnéto-visuel-inertiel",
    "Learned and Hybrid Strategies for Control and Planning of Highly Automated Vehicles",
]

# Page chrome. NOTE(review): set_page_config should run before any rendering
# command; load_model() above only builds a cached resource, which is fine.
st.set_page_config(page_title="Sentence search – multilingual-e5-small", layout="wide")
st.title("🔎 Sentence search with multilingual-e5-small")
st.write("Type a query and find the most similar research titles.")
@st.cache_resource(show_spinner=False)
def compute_corpus_embeddings(_model: SentenceTransformer, sents: list[str]) -> torch.Tensor:
    """Encode *sents* into one embedding tensor, cached on the sentence list.

    The leading underscore on ``_model`` tells Streamlit not to hash the model
    object — only ``sents`` participates in the cache key.
    """
    embeddings = _model.encode(sents, convert_to_tensor=True)
    return embeddings
# -----------------------------
# Corpus configuration (manual text or file upload)
# -----------------------------
st.subheader("Corpus configuration")

# Pre-fill the text area with the bundled example titles, one per line.
default_corpus_text = "\n".join(EXAMPLE_SENTENCES)

col_left, col_right = st.columns(2)
with col_left:
    # Free-text corpus entry; each non-empty line becomes one sentence.
    manual_text = st.text_area(
        "Sentences to embed (one per line)",
        value=default_corpus_text,
        height=220,
    )
with col_right:
    # Optional tabular corpus; text columns are selected further below.
    uploaded_file = st.file_uploader(
        "Or upload a corpus file (CSV, XLSX, JSON)",
        type=["csv", "xlsx", "xls", "json"],
    )

# Accumulators filled by the upload path or the manual-text fallback.
df = None
corpus_sentences: list[str] = []
corpus_source_desc = ""
# Parse the uploaded file into a DataFrame, dispatching on its extension.
# Any parse failure is surfaced to the user and leaves df as None.
if uploaded_file is not None:
    try:
        name = uploaded_file.name.lower()
        if name.endswith(".csv"):
            df = pd.read_csv(uploaded_file)
        elif name.endswith((".xlsx", ".xls")):
            df = pd.read_excel(uploaded_file)
        elif name.endswith(".json"):
            df = pd.read_json(uploaded_file)
    except Exception as e:
        st.error(f"Failed to read uploaded file: {e}")
# Build the corpus from the uploaded DataFrame: the user picks text columns,
# which are concatenated per row into one sentence.
if df is not None and not df.empty:
    st.markdown("**Uploaded data preview**")
    st.dataframe(df.head(100), use_container_width=True)
    # Only object/string dtype columns are offered as text sources.
    text_columns = df.select_dtypes(include=["object", "string"]).columns.tolist()
    if not text_columns:
        st.warning("No text columns detected in the uploaded file. Falling back to manual sentences.")
    else:
        selected_cols = st.multiselect(
            "Columns to concatenate as text (per row)",
            text_columns,
            default=text_columns[:1],
        )
        if selected_cols:
            # BUG FIX: fillna must run BEFORE astype(str). The original did
            # astype(str).fillna(""), but astype(str) turns NaN into the
            # literal string "nan", so fillna never fired and "nan" leaked
            # into the corpus sentences.
            combined = df[selected_cols].fillna("").astype(str)
            corpus_sentences = (
                combined.apply(
                    lambda row: " ".join(v.strip() for v in row if isinstance(v, str) and v.strip()),
                    axis=1,
                )
                .tolist()
            )
            # Drop rows that produced an empty string after stripping.
            corpus_sentences = [s for s in corpus_sentences if s]
            corpus_source_desc = f"Uploaded file – columns: {', '.join(selected_cols)} (rows: {len(corpus_sentences)})"
        else:
            st.info("Select at least one column to build the corpus. Using manual sentences instead.")
# Fall back to the manual text area when the upload path produced nothing.
if not corpus_sentences:
    lines = manual_text.splitlines()
    corpus_sentences = [line.strip() for line in lines if line.strip()]
    corpus_source_desc = f"Manual sentences (count: {len(corpus_sentences)})"
# Guard clause: with no corpus there is nothing to embed or search. Halting
# the script run here (st.stop) keeps the rest of the file at top level —
# otherwise every later statement would reference an undefined
# corpus_embeddings when the corpus is empty.
if not corpus_sentences:
    st.warning("No sentences available to embed. Please add sentences or upload a file.")
    st.stop()

st.caption(corpus_source_desc)
# Cached on the sentence list; recomputed only when the corpus changes.
corpus_embeddings = compute_corpus_embeddings(emb_model, corpus_sentences)
# Query controls: free-text query, a Top-N slider bounded by the corpus size,
# and Search / Clear buttons side by side.
query = st.text_input("Query", value="autonomous vehicle and image")
top_n = st.slider(
    "Top N results",
    min_value=1,
    max_value=max(1, len(corpus_sentences)),
    value=min(5, len(corpus_sentences)),
)

col_search, col_clear = st.columns([3, 1])
with col_search:
    run_search = st.button("Search")
with col_clear:
    if st.button("Clear"):
        # Reset widget state (keeps cached model/embeddings)
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        # BUG FIX: without an immediate rerun the widgets keep their stale
        # values for the remainder of this script run; rerun so the cleared
        # state is reflected right away.
        st.rerun()
if run_search:
    # --- Ranking -------------------------------------------------------
    # Embed the query and score it against every corpus embedding
    # (SentenceTransformer.similarity returns cosine scores by default).
    query_embedding = emb_model.encode([query], convert_to_tensor=True)
    similarities = emb_model.similarity(query_embedding, corpus_embeddings)  # (1, len(corpus_sentences))
    top_n_indices = torch.topk(similarities, top_n, dim=1).indices[0]

    st.write(f"Top {top_n} titres les plus proches de '{query}':")
    for i in top_n_indices:
        idx = int(i.item())
        st.markdown(
            f"- **Sentence**: {corpus_sentences[idx]} \n"
            f" **Similarity score**: {similarities[0][idx].item():.4f}"
        )

    # -----------------------------
    # Embedding map (PCA 2D) with Plotly
    # -----------------------------
    # NOTE: this section must stay inside `if run_search:` — it reads
    # query_embedding and top_n_indices, which only exist after a search.
    try:
        # Stack corpus embeddings and query embedding into one matrix so a
        # single PCA fit places them in the same 2D space.
        corpus_np = corpus_embeddings.detach().cpu().numpy()
        query_np = query_embedding.detach().cpu().numpy()  # (1, d)
        all_embeddings = np.vstack([corpus_np, query_np])  # (N+1, d)

        reduced = PCA(n_components=2).fit_transform(all_embeddings)  # (N+1, 2)
        xs = reduced[:, 0]
        ys = reduced[:, 1]

        # Color-code: top-N hits vs the rest of the corpus vs the query.
        top_set = {int(i.item()) for i in top_n_indices}
        kinds = []
        labels = []
        for idx, sent in enumerate(corpus_sentences):
            kinds.append("Top N result" if idx in top_set else "Corpus")
            labels.append(f"#{idx} {sent[:10]}{'…' if len(sent) > 10 else ''}")
        # The query was stacked last, so it aligns with reduced[-1].
        kinds.append("Query")
        labels.append(f"QUERY: {query[:10]}{'…' if len(query) > 10 else ''}")

        df_plot = pd.DataFrame({"x": xs, "y": ys, "kind": kinds, "label": labels})
        color_map = {
            "Corpus": "lightgray",
            "Top N result": "green",
            "Query": "red",
        }
        fig = px.scatter(
            df_plot,
            x="x",
            y="y",
            color="kind",
            text="label",
            color_discrete_map=color_map,
            title="Embedding map (PCA 2D)",
        )
        fig.update_traces(textposition="top center", marker=dict(size=10))
        fig.update_layout(
            xaxis_title="PC1",
            yaxis_title="PC2",
            height=650,
            showlegend=True,
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.warning(f"Could not generate embedding map: {e}")