samiha123 committed on
Commit
07c3ebf
·
1 Parent(s): 3e8c377

first commit

Browse files
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit==1.41.1
2
+ scipy
3
+ pystemmer
4
+ scikit-learn
5
+ bm25s
6
+ transformers
7
+ torch
src/data_final_cleaned.json ADDED
The diff for this file is too large to render. See raw diff
 
src/embedding_function.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module provides functionality to embed texts using the Hugging Face API.
3
+ It includes an EmbeddingFunction class for asynchronous embedding and a sync_embed function for synchronous embedding.
4
+ """
5
+ from huggingface_hub import InferenceClient
6
+ import asyncio
7
+ import os
8
+ from typing import List, Optional, Union
9
+ import os
10
+ from huggingface_hub import InferenceClient
11
+
12
+ import httpx
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+
17
+ TextType = Union[str, List[str]]
18
+
19
+
20
class EmbeddingFunction:
    """
    A class to handle embedding functions using the Hugging Face API.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        batch_size: int = 50,
        api_url: Optional[str] = None,
    ):
        """
        Initialize the EmbeddingFunction.

        Args:
            model (str): The model to use for embedding.
            api_key (Optional[str]): The API key for the Hugging Face API. If not provided,
                it will be fetched from the environment variable `HF_API_KEY`.
            batch_size (int): The number of texts to process in a single batch. Default is 50.
            api_url (Optional[str]): Custom API URL for Hugging Face inference endpoint.
        """
        # Bug fix: the original __init__ discarded all of its arguments, so
        # instances carried no configuration at all.
        self.model = model
        # Fall back to the HF_API_KEY environment variable, as documented above.
        self.api_key = api_key if api_key is not None else os.getenv("HF_API_KEY")
        self.batch_size = batch_size
        self.api_url = api_url
44
def sync_embed(texts: str, model: str, api_key: str) -> list:
    """
    Extract the embeddings of a text via the Hugging Face Inference API.

    Args:
        texts (str): The text to encode.
        model (str): The Hugging Face model to use.
        api_key (str): The Hugging Face API key.

    Returns:
        list: The embeddings of the text.
    """
    inference_client = InferenceClient(provider="hf-inference", api_key=api_key)
    extracted = inference_client.feature_extraction(texts, model=model)
    # Only the first embedding of the API response is returned.
    return extracted[0]
src/embeddings_cache/all-MiniLM-L6-v2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8df55f6acd3449885335d0199fa46ea1f243d627370c0d741c71f96ac9ee9a05
3
+ size 1875998
src/embeddings_cache/distiluse-base-multilingual-cased-v2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57f452770dd3ba3527bbd587465b0ea6ae6d29c27ecc5278a5b56c1d7adad52c
3
+ size 3485726
src/embeddings_cache/e5-small-v2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ed0114d23bdee9ad0fc620f5f593e4ca20c43173539c8f9bd2f3cc9b807f8da
3
+ size 4439841
src/embeddings_cache/multi-qa-MiniLM-L6-cos-v1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7aebb55ef48daaf077e1ccca89345ea530b8eeccb3d3b1216388c93025a10456
3
+ size 4439841
src/embeddings_cache/multilingual-e5-large.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5871f6ba37ea5bae28531c12169148c36cf7ecd414a4af5aadbc01ca77890c3
3
+ size 2434314
src/embeddings_cache/multilingual-e5-small.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a42cf13b938bd83500736c2644758950a0b2ad5aadc2b11d5c7b719e319eead1
3
+ size 1875998
src/embeddings_cache/paraphrase-mpnet-base-v2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78c5816feafe64571efe8a43cd3db4c7c12f4f040e596c8eb09af7b60f58429e
3
+ size 1897738
src/retrieval.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from retrievals import TFIDFRetriever
3
+ import pprint
4
+ from retrievals import BM25Retriever
5
+ from typing import Callable, List
6
+ import numpy as np
7
+ from typing import Callable
8
+ import bm25s
9
+ import numpy as np
10
+ import Stemmer
11
+ from scipy.spatial.distance import cdist
12
+ from sklearn.feature_extraction.text import TfidfVectorizer
13
+ import asyncio
14
+ from typing import List, Union, Optional
15
+ from transformers import AutoTokenizer, AutoModel
16
+ import torch
17
+ import os
18
+ from typing import List, Optional, Union
19
+ import requests
20
+ import numpy as np
21
+ from typing import Callable, List
22
+ from scipy.spatial.distance import cdist
23
+ from retrieval_evaluation.src.embedding_function import sync_embed
24
+
25
+
26
+
27
####################################################################################

# Load the cleaned corpus that ships next to this module. Resolving the path
# against __file__ (instead of the bare filename) makes the script independent
# of the current working directory — the old relative path only worked when
# the process was started from inside src/.
_DATA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data_final_cleaned.json"
)

with open(_DATA_PATH, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Normalise raw records into the {"cleaned_content", "metadata"} shape the
# retriever classes expect. Records without a "docs" key are skipped.
formatted_data = []
for item in raw_data:
    if "docs" in item:
        formatted_data.append({
            "cleaned_content": item["docs"].get("content", ""),
            "metadata": {"source": item["docs"].get("metadata", "")},
        })
44
######################################"TF_IDF########################################

def get_retrieval_tf_idf(query):
    """
    Run a TF-IDF search over the formatted corpus and return the top 3 hits.

    Args:
        query (str): The user question.

    Returns:
        dict: {'json': {'question': ..., 'results': [{'content', 'metadata',
            'score'}, ...]}} sorted by decreasing relevance.
    """
    retriever = TFIDFRetriever()
    retriever.index_data(formatted_data)

    hits = retriever.search(query, k=3)

    return {
        'json': {
            'question': query,
            'results': [
                {
                    'content': hit['text'],
                    'metadata': hit['source'],
                    'score': float(hit['score']),
                }
                for hit in hits
            ],
        }
    }
68
+
69
+
70
##################################BM25##########################################


def get_retrieval_bm25(query):
    """
    Run a BM25 search over the formatted corpus and return the top 3 hits.

    Args:
        query (str): The user question.

    Returns:
        dict: {'json': {'question': ..., 'results': [{'content', 'metadata',
            'score'}, ...]}} sorted by decreasing relevance.
    """
    retriever = BM25Retriever()
    retriever.index_data(formatted_data)

    hits = retriever.search(query, k=3)

    return {
        'json': {
            'question': query,
            'results': [
                {
                    'content': hit['text'],
                    'metadata': hit['source'],
                    'score': float(hit['score']),
                }
                for hit in hits
            ],
        }
    }
94
+
95
+
96
+ #######################################dense retrieval###################################
97
+ import numpy as np
98
+ from typing import Callable, List
99
+ from scipy.spatial.distance import cdist
100
+ import pickle
101
+ import os
102
+
103
class DenseRetriever:
    """
    A retriever model that uses dense embeddings for indexing and searching documents.

    Attributes:
        vectorizer (Callable): The function used to generate embeddings.
        index (np.ndarray): The indexed embeddings.
        data (list): The data to be indexed.
    """

    def __init__(self, vectorizer: Callable):
        """
        Initialize the DenseRetriever.

        Args:
            vectorizer (Callable): The function to generate embeddings.
        """
        self.vectorizer = vectorizer
        self.index = None
        self.data = None

    def load_index(self, filepath: str):
        """
        Load the index and metadata from a pickle file.

        Args:
            filepath (str): Path to the .pkl file containing 'index' and 'data'.
        """
        with open(filepath, 'rb') as handle:
            payload = pickle.load(handle)
        self.index = payload['index']
        self.data = payload['data']

    def index_data(self, data: List[dict]):
        """
        Indexes the provided data using dense embeddings.

        Args:
            data (list): Documents to index. Each document must be a dictionary
                with a 'cleaned_content' key holding the text to be indexed.
        """
        self.data = data
        texts = [entry["cleaned_content"] for entry in data]
        self.index = np.array(self.vectorizer(texts))

    def search(self, query: str, k: int = 5) -> List[dict]:
        """
        Searches the indexed data for the given query using cosine similarity.

        Args:
            query (str): The search query.
            k (int): The number of top results to return.

        Returns:
            list: Dictionaries with the source, text, and score of the top-k results.

        Raises:
            ValueError: If the vectorizer returns None, or either the query
                embedding or the index is not a compatible 2-D array.
        """
        embedded = self.vectorizer([query])  # must yield a list or np.ndarray

        if embedded is None:
            raise ValueError("La fonction vectorizer a retourné None.")

        embedded = np.array(embedded)

        # Promote a single vector to shape (1, dim) so cdist receives 2-D input.
        if embedded.ndim == 1:
            embedded = embedded[np.newaxis, :]

        if embedded.ndim != 2:
            raise ValueError("query_embedding doit être un tableau 2D.")

        if self.index.ndim != 2:
            raise ValueError("L'index dense doit être un tableau 2D.")

        if self.index.shape[1] != embedded.shape[1]:
            raise ValueError(f"Dimensions incompatibles entre query ({embedded.shape[1]}) et index ({self.index.shape[1]}).")

        distances = cdist(embedded, self.index, metric="cosine")[0]
        nearest = distances.argsort()[:k]
        # Cosine similarity = 1 - cosine distance.
        return [
            {
                "source": self.data[i]["metadata"]["source"],
                "text": self.data[i]["cleaned_content"],
                "score": 1 - distances[i],
            }
            for i in nearest
        ]

    def predict(self, query: str, k: int) -> List[dict]:
        """Alias for :meth:`search`."""
        return self.search(query, k)
197
+
198
+ import os
199
+ import pickle
200
def get_retrieval_dense(query, model=None, api_key=None):
    """
    Run a dense-embedding search against a pre-computed pickle index.

    Args:
        query (str): The user question.
        model (str | list): Hugging Face model id; a one-element list is
            tolerated and its first element is used.
        api_key (str, optional): Hugging Face API key. Falls back to the
            HF_API_KEY environment variable (previously this parameter was
            accepted but silently ignored).

    Returns:
        dict: {'json': {'question': ..., 'results': [...]}} with the top 3 hits.

    Raises:
        ValueError: If no model is given.
        FileNotFoundError: If no cached index exists for the model.
    """
    if model is None:
        raise ValueError("Model must be specified")

    if isinstance(model, list):
        model = model[0]  # defensive: tolerate a one-element list

    model_filename = model.split("/")[-1] + ".pkl"
    index_path = os.path.join("embeddings_cache", model_filename)

    if not os.path.exists(index_path):
        raise FileNotFoundError(f"L'index pour le modèle {model} est introuvable à l'emplacement : {index_path}")

    # Honour an explicitly passed key; otherwise read the environment.
    resolved_key = api_key if api_key is not None else os.getenv("HF_API_KEY")
    dr = DenseRetriever(
        vectorizer=lambda docs: sync_embed(texts=docs, model=model, api_key=resolved_key)
    )
    # NOTE(review): pickle.load is only safe on trusted cache files — never
    # point this at user-supplied data.
    # Reuse the class's own loader instead of re-implementing the pickle read
    # and assigning dr.index / dr.data by hand, as the original code did.
    dr.load_index(index_path)

    results = dr.search(query, k=3)

    return {
        'json': {
            'question': query,
            'results': [
                {
                    'content': hit['text'],
                    'metadata': hit['source'],
                    'score': float(hit['score']),
                }
                for hit in results
            ],
        }
    }
src/retrievals.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains implementations of various retriever models for document retrieval.
3
+ """
4
+ from typing import Callable
5
+ import bm25s
6
+ import numpy as np
7
+ import Stemmer
8
+ from scipy.spatial.distance import cdist
9
+ from sklearn.feature_extraction.text import TfidfVectorizer
10
+ import asyncio
11
+ from typing import List, Union, Optional
12
+ from transformers import AutoTokenizer, AutoModel
13
+ import torch
14
+ import os
15
+ from typing import List, Optional, Union
16
+ import requests
17
+ import numpy as np
18
+ from typing import Callable, List
19
+ from scipy.spatial.distance import cdist
20
+
21
+
22
+
23
class TFIDFRetriever:
    """
    A retriever model that uses TF-IDF for indexing and searching documents.

    Attributes:
        vectorizer (TfidfVectorizer): The TF-IDF vectorizer.
        index (scipy.sparse matrix): The indexed TF-IDF vectors.
        data (list): The original data used for indexing.
    """

    def __init__(self):
        self.vectorizer = TfidfVectorizer()
        self.index = None
        self.data = None
        # NOTE(review): this stemmer is never used by the class; kept only to
        # preserve the existing attribute surface for any external callers.
        self.stemmer = Stemmer.Stemmer("english")

    def index_data(self, data):
        """
        Indexes the provided data using TF-IDF.

        Args:
            data (list): A list of documents to be indexed. Each document should be a dictionary
                containing a key 'cleaned_content' with the text to be indexed.
        """
        self.data = data
        docs = [doc["cleaned_content"] for doc in data]
        self.index = self.vectorizer.fit_transform(docs)

    def search(self, query, k=5):
        """
        Searches the indexed data for the given query using cosine similarity.

        Args:
            query (str): The search query.
            k (int): The number of top results to return. Default is 5.

        Returns:
            list: A list of dictionaries containing the source, text, and score of the top-k results.
        """
        query_vec = self.vectorizer.transform([query])
        # Fix: use .toarray() rather than .todense() — cdist wants plain
        # ndarrays, and np.matrix (what todense returns) is deprecated.
        cosine_distances = cdist(
            query_vec.toarray(), self.index.toarray(), metric="cosine"
        )[0]
        top_k_indices = cosine_distances.argsort()[:k]
        output = []
        for idx in top_k_indices:
            output.append(
                {
                    "source": self.data[idx]["metadata"]["source"],
                    "text": self.data[idx]["cleaned_content"],
                    # Cosine similarity = 1 - cosine distance.
                    "score": 1 - cosine_distances[idx],
                }
            )
        return output

    def predict(self, query: str, k: int):
        """
        Predicts the top-k results for the given query.

        Args:
            query (str): The search query.
            k (int): The number of top results to return.

        Returns:
            list: A list of dictionaries containing the source, text, and score of the top-k results.
        """
        return self.search(query, k)
90
+
91
+
92
+
93
+ ################################################BM25##########################################
94
+
95
+
96
class BM25Retriever:
    """
    A retriever model that uses BM25 for indexing and searching documents.

    Attributes:
        index (bm25s.BM25): The BM25 index.
        data (list): The data to be indexed.
    """

    def __init__(self):
        self.index = bm25s.BM25()
        self.data = None

    def index_data(self, data):
        """
        Indexes the provided data using BM25.

        Args:
            data (list): Documents to index. Each document must be a dictionary
                with a 'cleaned_content' key holding the text to index.
        """
        self.data = data
        texts = [entry["cleaned_content"] for entry in data]
        tokenized = bm25s.tokenize(texts, show_progress=False)
        self.index.index(tokenized, show_progress=False)

    def search(self, query, k=5):
        """
        Searches the indexed data for the given query using BM25.

        Args:
            query (str): The search query.
            k (int): The number of top results to return. Default is 5.

        Returns:
            list: Dictionaries with the source, text, and score of the top-k results.
        """
        tokens = bm25s.tokenize(query, show_progress=False)
        # retrieve() returns a (documents, scores) pair; both arrays are
        # shaped (n_queries, k).
        docs, scores = self.index.retrieve(
            tokens, corpus=self.data, k=k, show_progress=False
        )
        return [
            {
                "source": docs[0, col]["metadata"]["source"],
                "text": docs[0, col]["cleaned_content"],
                "score": scores[0, col],
            }
            for col in range(docs.shape[1])
        ]

    def predict(self, query: str, k: int):
        """
        Predicts the top-k results for the given query.

        Args:
            query (str): The search query.
            k (int): The number of top results to return.

        Returns:
            list: Dictionaries with the source, text, and score of the top-k results.
        """
        return self.search(query, k)
165
+
166
+
167
+
168
+ ###########################################EMBEDDINGS##########################################
169
+
170
+
171
class DenseRetriever:
    """
    A retriever model that uses dense embeddings for indexing and searching documents.

    Attributes:
        vectorizer (Callable): The function used to generate embeddings.
        index (np.ndarray): The indexed embeddings.
        data (list): The data to be indexed.
    """

    def __init__(self, vectorizer: Callable, batch_size: int = 50):
        """
        Initialize the DenseRetriever.

        Args:
            vectorizer (Callable): The function to generate embeddings.
            batch_size (int): The number of texts to process in a single batch. Default is 50.
        """
        self.vectorizer = vectorizer
        self.batch_size = batch_size
        self.index = None
        self.data = None

    def index_data(self, data: List[dict]):
        """
        Indexes the provided data using dense embeddings.

        Args:
            data (list): Documents to index. Each document must be a dictionary
                with a 'cleaned_content' key holding the text to be indexed.
        """
        self.data = data
        texts = [entry["cleaned_content"] for entry in data]
        self.index = np.array(self.vectorizer(texts))

    def search(self, query: str, k: int = 5) -> List[dict]:
        """
        Searches the indexed data for the given query using cosine similarity.

        Args:
            query (str): The search query.
            k (int): The number of top results to return. Default is 5.

        Returns:
            list: Dictionaries with the source, text, and score of the top-k results.
        """
        embedded = self.vectorizer([query])
        distances = cdist(embedded, self.index, metric="cosine")[0]
        nearest = distances.argsort()[:k]
        # Cosine similarity = 1 - cosine distance.
        return [
            {
                "source": self.data[i]["metadata"]["source"],
                "text": self.data[i]["cleaned_content"],
                "score": 1 - distances[i],
            }
            for i in nearest
        ]

    def predict(self, query: str, k: int) -> List[dict]:
        """
        Predicts the top-k results for the given query.

        Args:
            query (str): The search query.
            k (int): The number of top results to return.

        Returns:
            list: Dictionaries with the source, text, and score of the top-k results.
        """
        return self.search(query, k)
src/streamlit_app.py CHANGED
@@ -1,40 +1,117 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
import streamlit as st
import json
from retrievals import TFIDFRetriever, BM25Retriever
from retrieval import get_retrieval_tf_idf, get_retrieval_bm25, get_retrieval_dense
from embedding_function import sync_embed
import numpy as np
import os

st.set_page_config(
    page_title="Vector Store Query App",
    layout="wide",
    initial_sidebar_state="expanded"
)

st.markdown("""
<style>
    .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
    .section {
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
    }
</style>
""", unsafe_allow_html=True)

with st.sidebar:
    st.title("About")
    st.markdown("""
    This app allows you to query a vector store and view results in both JSON format
    and rendered markdown. Enter your question in the main panel and click 'Search'.
    """)

    retrieval_method = st.selectbox(
        "Choose the retrieval method:",
        ["Sparse Retrievals", "Dense Retrievals", "Hybrid Retrievals"]
    )

    if retrieval_method == "Sparse Retrievals":
        sparse_method = st.selectbox(
            "Choose a Sparse Retrieval method:",
            ["BM25", "TF-IDF"]
        )
        st.write(f"Selected Sparse Method: {sparse_method}")

    elif retrieval_method == "Dense Retrievals":
        model_selection = st.selectbox(
            "Choose a model:",
            [
                "sentence-transformers/all-MiniLM-L6-v2",
                "intfloat/multilingual-e5-large"
            ]
        )
        st.write(f"Selected model: {model_selection}")
        # Persist the choice so the form handler below can read it back.
        st.session_state.model_selection = model_selection

st.title("Vector Store Query Interface")

if 'results' not in st.session_state:
    st.session_state.results = None

with st.form("query_form"):
    col1, col2 = st.columns([4, 1])
    with col1:
        query = st.text_input(
            "Enter your question:",
            placeholder="What are you looking for?",
            label_visibility="collapsed"
        )
    with col2:
        st.write("")
        if st.form_submit_button("Search", use_container_width=True):
            if query:
                if retrieval_method == "Dense Retrievals":
                    model_selection = st.session_state.get('model_selection')
                    api_key = os.getenv("HF_API_KEY")
                    # Fix: the old handler also called sync_embed(query, ...)
                    # here and discarded the result — one wasted API round-trip
                    # per search. get_retrieval_dense embeds the query itself.
                    st.session_state.results = get_retrieval_dense(query, model=model_selection, api_key=api_key)
                elif retrieval_method == "Sparse Retrievals" and sparse_method == "TF-IDF":
                    st.session_state.results = get_retrieval_tf_idf(query)
                elif retrieval_method == "Sparse Retrievals" and sparse_method == "BM25":
                    st.session_state.results = get_retrieval_bm25(query)
                else:
                    # Fix: "Hybrid Retrievals" is offered in the sidebar but
                    # has no implementation — previously it failed silently.
                    st.warning("Hybrid retrieval is not implemented yet.")
            else:
                st.warning("Please enter a question")

if st.session_state.results:
    st.divider()
    st.subheader("Results")

    col_left, col_right = st.columns([1, 2], gap="large")

    with col_left:
        st.markdown("**JSON Output**")
        st.code(
            json.dumps(st.session_state.results['json'], indent=2),
            language='json'
        )

    with col_right:
        st.markdown("**Document Content**")

        for i, doc in enumerate(st.session_state.results['json']['results']):
            with st.container():
                st.markdown(f"### Document {i+1}")
                st.markdown(doc['content'])
                st.markdown(f"**Source:** {doc['metadata']}")
                st.divider()

elif st.session_state.results is None:
    st.info("👈 Enter a question and click Search to get started")