# Streamlit demo: semantic sentence search with intfloat/multilingual-e5-small.
# Third-party dependencies, grouped alphabetically (PEP 8 style).
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
@st.cache_resource
def load_model() -> SentenceTransformer:
    """Load and cache the multilingual-e5-small sentence encoder.

    Returns:
        A ready-to-use SentenceTransformer instance.

    ``@st.cache_resource`` keeps one model instance alive across Streamlit
    reruns; without it the weights are re-loaded on every widget
    interaction, which is slow and wasteful.
    """
    return SentenceTransformer("intfloat/multilingual-e5-small")


emb_model = load_model()
# Default corpus: research/thesis titles (mixed English and French) used to
# pre-populate the manual-entry text area.
EXAMPLE_SENTENCES = [
    "LiDAR-based domain generalization and unknown 3D object detection",
    "Information balance and density in augmented reality assistance",
    "Learning Semantics and Geometry for Scene Understanding",
    "Integrating Expert Knowledge with Deep Reinforcement Learning Methods for Autonomous Driving",
    "Deep Reinforcement Learning and Learning from Demonstrations for Robot Manipulators",
    "Smart vehicule trajectory prediction in various autonomous driving scenarios",
    "Exploring LiDAR Odometries through Classical, Deep and Inertial perspectives",
    "Généralisation de domaine pour la segmentation sémantique de données LiDAR pour le véhicule autonome",
    "Localisation en intérieur par SLAM magnéto-visuel-inertiel",
    "Learned and Hybrid Strategies for Control and Planning of Highly Automated Vehicles",
]
# Page chrome: wide layout, title, and a one-line usage hint.
st.set_page_config(page_title="Sentence search – multilingual-e5-small", layout="wide")
st.title("🔎 Sentence search with multilingual-e5-small")
st.write("Type a query and find the most similar research titles.")
@st.cache_data
def compute_corpus_embeddings(_model: SentenceTransformer, sents: list[str]) -> torch.Tensor:
    """Encode *sents* into a single tensor of embeddings, cached per corpus.

    Args:
        _model: The sentence encoder. The leading underscore tells
            ``st.cache_data`` to skip hashing this (unhashable) object —
            only ``sents`` keys the cache, exactly as the original comment
            described; the decorator itself was missing, so no caching
            actually happened.
        sents: Sentences to embed, one string each.

    Returns:
        A tensor of shape (len(sents), embedding_dim).
    """
    # _model is ignored for caching purposes (only sents is hashed)
    return _model.encode(sents, convert_to_tensor=True)
# -----------------------------
# Corpus configuration (manual text or file upload)
# -----------------------------
st.subheader("Corpus configuration")

# Pre-fill the manual text area with the bundled example titles.
default_corpus_text = "\n".join(EXAMPLE_SENTENCES)

col_left, col_right = st.columns(2)
with col_left:
    manual_text = st.text_area(
        "Sentences to embed (one per line)",
        value=default_corpus_text,
        height=220,
    )
with col_right:
    uploaded_file = st.file_uploader(
        "Or upload a corpus file (CSV, XLSX, JSON)",
        type=["csv", "xlsx", "xls", "json"],
    )

# Accumulators filled by the upload/manual branches below.
df = None
corpus_sentences: list[str] = []
corpus_source_desc = ""
if uploaded_file is not None:
    try:
        # Pick a pandas reader from the (lowercased) file extension.
        lowered = uploaded_file.name.lower()
        reader = None
        if lowered.endswith(".csv"):
            reader = pd.read_csv
        elif lowered.endswith((".xlsx", ".xls")):
            reader = pd.read_excel
        elif lowered.endswith(".json"):
            reader = pd.read_json
        # Unmatched suffixes leave df as None, same as before.
        if reader is not None:
            df = reader(uploaded_file)
    except Exception as e:
        st.error(f"Failed to read uploaded file: {e}")
if df is not None and not df.empty:
    st.markdown("**Uploaded data preview**")
    st.dataframe(df.head(100), use_container_width=True)

    # Only object/string dtype columns are offered as text sources.
    text_columns = df.select_dtypes(include=["object", "string"]).columns.tolist()
    if not text_columns:
        st.warning("No text columns detected in the uploaded file. Falling back to manual sentences.")
    else:
        selected_cols = st.multiselect(
            "Columns to concatenate as text (per row)",
            text_columns,
            default=text_columns[:1],
        )
        if selected_cols:
            # BUGFIX: fillna must run BEFORE astype(str) — astype(str)
            # converts NaN cells to the literal string "nan", which then
            # leaked into the corpus as fake text.
            combined = df[selected_cols].fillna("").astype(str)
            # One sentence per row: join the non-empty, stripped cell values.
            corpus_sentences = (
                combined.apply(
                    lambda row: " ".join(v.strip() for v in row if isinstance(v, str) and v.strip()),
                    axis=1,
                )
                .tolist()
            )
            # Drop rows whose selected cells were all empty.
            corpus_sentences = [s for s in corpus_sentences if s]
            corpus_source_desc = f"Uploaded file – columns: {', '.join(selected_cols)} (rows: {len(corpus_sentences)})"
        else:
            st.info("Select at least one column to build the corpus. Using manual sentences instead.")

# Fall back to the manual text area (one sentence per non-blank line).
if not corpus_sentences:
    corpus_sentences = [s.strip() for s in manual_text.splitlines() if s.strip()]
    corpus_source_desc = f"Manual sentences (count: {len(corpus_sentences)})"
if not corpus_sentences:
    st.warning("No sentences available to embed. Please add sentences or upload a file.")
else:
    st.caption(corpus_source_desc)
    corpus_embeddings = compute_corpus_embeddings(emb_model, corpus_sentences)

    query = st.text_input("Query", value="autonomous vehicle and image")
    # Slider is bounded by the corpus size, so topk below can never overflow.
    top_n = st.slider(
        "Top N results",
        min_value=1,
        max_value=max(1, len(corpus_sentences)),
        value=min(5, len(corpus_sentences)),
    )

    col_search, col_clear = st.columns([3, 1])
    with col_search:
        run_search = st.button("Search")
    with col_clear:
        if st.button("Clear"):
            # Reset widget state (keeps cached model/embeddings).
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            # BUGFIX: without an explicit rerun the widgets already rendered
            # this pass keep their values, so "Clear" appeared to do nothing.
            st.rerun()

    if run_search:
        query_embedding = emb_model.encode([query], convert_to_tensor=True)
        similarities = emb_model.similarity(query_embedding, corpus_embeddings)  # (1, len(corpus_sentences))
        top_n_indices = torch.topk(similarities, top_n, dim=1).indices[0]

        st.write(f"Top {top_n} titres les plus proches de '{query}':")
        for i in top_n_indices:
            idx = int(i.item())
            st.markdown(
                f"- **Sentence**: {corpus_sentences[idx]} \n"
                f" **Similarity score**: {similarities[0][idx].item():.4f}"
            )

        # -----------------------------
        # Embedding map (PCA 2D) with Plotly
        # -----------------------------
        try:
            # Stack corpus embeddings and the query embedding into one matrix.
            corpus_np = corpus_embeddings.detach().cpu().numpy()
            query_np = query_embedding.detach().cpu().numpy()  # (1, d)
            all_embeddings = np.vstack([corpus_np, query_np])  # (N+1, d)

            pca = PCA(n_components=2)
            reduced = pca.fit_transform(all_embeddings)  # (N+1, 2)
            xs = reduced[:, 0]
            ys = reduced[:, 1]

            # Color-code each point: top-N hit, plain corpus, or the query.
            top_set = {int(i.item()) for i in top_n_indices}
            kinds = []
            labels = []
            for idx, sent in enumerate(corpus_sentences):
                kinds.append("Top N result" if idx in top_set else "Corpus")
                labels.append(f"#{idx} {sent[:10]}{'…' if len(sent) > 10 else ''}")
            # The query point is the last row of the stacked matrix.
            kinds.append("Query")
            labels.append(f"QUERY: {query[:10]}{'…' if len(query) > 10 else ''}")

            df_plot = pd.DataFrame(
                {
                    "x": xs,
                    "y": ys,
                    "kind": kinds,
                    "label": labels,
                }
            )
            color_map = {
                "Corpus": "lightgray",
                "Top N result": "green",
                "Query": "red",
            }
            fig = px.scatter(
                df_plot,
                x="x",
                y="y",
                color="kind",
                text="label",
                color_discrete_map=color_map,
                title="Embedding map (PCA 2D)",
            )
            fig.update_traces(textposition="top center", marker=dict(size=10))
            fig.update_layout(
                xaxis_title="PC1",
                yaxis_title="PC2",
                height=650,
                showlegend=True,
            )
            st.plotly_chart(fig, use_container_width=True)
        except Exception as e:
            st.warning(f"Could not generate embedding map: {e}")