Kdnv commited on
Commit
091b7db
·
1 Parent(s): bd0ee06

day2 push

Browse files
models/index_cosine.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6932fe0946d8e0df89abd2f34ee2d76a6c4a3cb14de59afe7b3974401374ac9
3
+ size 11679789
models/index_dot.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70ea9a98e913f55f9fc700b70fdc7c0ba46fa9fab8e9090940145de0535c6b1b
3
+ size 11679789
models/index_l2.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c1a9ac9bf3594a2ce1749e946a4ac4f7947a310f5fe0a06e619ef73ac018872
3
+ size 11679789
pages/search.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from main import show_serial
4
+ import faiss
5
+ import numpy as np
6
+
7
+
8
+ @st.cache_resource
9
+ def load_model():
10
+ return SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
11
+
12
+
13
+ model = load_model()
14
+
15
+ indices = {
16
+ 'L2': faiss.read_index("models/index_l2.faiss"),
17
+ 'Dot': faiss.read_index("models/index_dot.faiss"),
18
+ 'Cos': faiss.read_index("models/index_cosine.faiss")
19
+ }
20
+
21
+ with st.form(key='submit'):
22
+ cols = st.columns(2)
23
+ with cols[0]:
24
+ prompt = st.text_area('Ваш запрос')
25
+ with cols[1]:
26
+ amount = st.slider(label='Количество рекомендаций', min_value=1, max_value=10, value=5)
27
+ metrics = st.radio(label='Метрика', options=('L2', 'Dot', 'Cos'), horizontal=True)
28
+ submit = st.form_submit_button('Искать', use_container_width=True)
29
+
30
+ if submit:
31
+ prompt_embedding = model.encode(prompt).astype('float32')
32
+ prompt_embedding = prompt_embedding[np.newaxis, :]
33
+
34
+ if metrics == 'Cos':
35
+ prompt_embedding = prompt_embedding / np.linalg.norm(prompt_embedding, axis=1, keepdims=True)
36
+
37
+ distances, indices_result = indices[metrics].search(prompt_embedding, amount)
38
+
39
+ sorted_indices = indices_result[0].tolist()
40
+ sorted_distances = distances[0].tolist()
41
+
42
+ for idx, score in zip(sorted_indices, sorted_distances):
43
+ show_serial(idx, (score, metrics))