annafilina commited on
Commit
11cb079
·
1 Parent(s): 0162a45

Update stri.py

Browse files
Files changed (1) hide show
  1. stri.py +28 -85
stri.py CHANGED
@@ -1,87 +1,30 @@
1
  import streamlit as st
2
- import torch
3
- import numpy as np
4
  import pandas as pd
5
- from PIL import Image
6
- from transformers import AutoTokenizer, AutoModel
7
- import re
8
- import pickle
9
- import requests
10
- from io import BytesIO
11
-
12
- st.title("Книжные рекомендации")
13
-
14
- # Загрузка модели и токенизатора
15
- model_name = "symanto/sn-xlm-roberta-base-snli-mnli-anli-xnli"
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
18
-
19
- # Загрузка датасета и аннотаций к книгам
20
- books = pd.read_csv('all+.csv')
21
- books.dropna(inplace=True)
22
-
23
- books = books[books['annotation'].apply(lambda x: len(x.split()) >= 40)]
24
- books.drop_duplicates(subset='title', keep='first', inplace=True)
25
- books = books.reset_index(drop=True)
26
-
27
-
28
- def data_preprocessing(text: str) -> str:
29
- text = re.sub(r'http\S+', " ", text) # удаляем ссылки
30
- text = re.sub(r'@\w+', ' ', text) # удаляем упоминания пользователей
31
- text = re.sub(r'#\w+', ' ', text) # удаляем хэштеги
32
- text = re.sub(r'<.*?>', ' ', text) # html tags
33
- return text
34
-
35
-
36
- for i in ['author', 'title', 'annotation']:
37
- books[i] = books[i].apply(data_preprocessing)
38
-
39
- annot = books['annotation']
40
-
41
- # Получение эмбеддингов аннотаций каждой книги в датасете
42
- length = 512
43
-
44
- # Определение запроса пользователя
45
- query = st.text_input("Введите запрос")
46
-
47
- if st.button('Сгенерировать'):
48
- with open("book_embeddingsN.pkl", "rb") as f:
49
- book_embeddings = pickle.load(f)
50
-
51
- query_tokens = tokenizer.encode_plus(
52
- query,
53
- add_special_tokens=True,
54
- max_length=length, # Ограничение на максимальную длину входной последовательности
55
- pad_to_max_length=True, # Дополним последовательность нулями до максимальной длины
56
- return_tensors='pt' # Вернём тензоры PyTorch
57
- )
58
-
59
- with torch.no_grad():
60
- query_outputs = model(**query_tokens)
61
- query_hidden_states = query_outputs.hidden_states[-1][:,0,:]
62
- query_hidden_states = torch.nn.functional.normalize(query_hidden_states)
63
-
64
-
65
- # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
66
- cosine_similarities = torch.nn.functional.cosine_similarity(
67
- query_embedding.squeeze(0),
68
- torch.stack(book_embeddings.cpu())
69
- )
70
-
71
- cosine_similarities = cosine_similarities.numpy()
72
-
73
- indices = np.argsort(cosine_similarities)[::-1] # Сортировка по убыванию
74
-
75
- num_books_per_page = st.selectbox("Количество книг на странице:", [3, 5, 10], index=0)
76
-
77
- for i in indices[:num_books_per_page]:
78
- cols = st.columns(2) # Создание двух столбцов для размещения информации и изображения
79
- cols[1].write("## " + books['title'][i])
80
- cols[1].markdown("**Автор:** " + books['author'][i])
81
- cols[1].markdown("**Аннотация:** " + books['annotation'][i])
82
- image_url = books['image_url'][i]
83
- response = requests.get(image_url)
84
- image = Image.open(BytesIO(response.content))
85
- cols[0].image(image)
86
- cols[0].write(cosine_similarities[i])
87
- cols[1].write("---")
 
1
  import streamlit as st
 
 
2
  import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+
6
+ # Read the CSV file
7
+ df = pd.read_csv('all+++.csv')
8
+
9
+ # Display the CSV file
10
+ st.title('CSV File Overview')
11
+ st.dataframe(df)
12
+
13
+ # Bar plot for genres
14
+ st.title('Genre Bar Plot')
15
+ genre_counts = df['genre'].value_counts()
16
+ plt.figure(figsize=(10, 6))
17
+ sns.barplot(x=genre_counts.index, y=genre_counts.values)
18
+ plt.xlabel('Genre')
19
+ plt.ylabel('Count')
20
+ plt.xticks(rotation=45)
21
+ st.pyplot()
22
+
23
+ # Distribution plot for annotation lengths
24
+ st.title('Annotation Length Distribution')
25
+ annotation_lengths = df['annotation'].str.len()
26
+ plt.figure(figsize=(10, 6))
27
+ sns.histplot(annotation_lengths, kde=True)
28
+ plt.xlabel('Annotation Length')
29
+ plt.ylabel('Count')
30
+ st.pyplot()