File size: 3,840 Bytes
8df2b21
 
 
cff9354
 
87e18be
8df2b21
 
 
 
87e18be
23221da
cff9354
 
 
 
 
 
 
 
87e18be
cff9354
 
 
 
 
 
 
23221da
ccd93f0
23221da
 
 
 
 
 
 
 
 
ccd93f0
7de79a5
cff9354
 
 
 
 
 
 
 
 
 
 
 
 
 
23221da
 
ccd93f0
8df2b21
 
cff9354
8df2b21
ccd93f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8df2b21
 
 
cff9354
 
23221da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import pandas as pd
import streamlit as st
import random
from langchain.chat_models.gigachat import GigaChat
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
import os

if 'button_clicked' not in st.session_state:
    st.session_state.button_clicked = False

api_key = os.getenv("api_chat")

@st.cache_data
def load_data():
    return pd.read_csv('data/data.csv', index_col=0)


@st.cache_resource
def load_chat():
    chat = GigaChat(
        credentials=api_key,
        verify_ssl_certs=False)
    return chat


chat = load_chat()


def show_serial(number, score=None):
    with st.container(border=True):
        col1, col2, col3 = st.columns(3)
        with col1:
            st.image(data.iloc[number, 1])
        with col2:
            st.subheader(data.iloc[number, 2])
            st.metric(label='IMDB', value=data.iloc[number, 5])
            st.caption(data.iloc[number, 3])
            if score:
                st.write(f'{score[1]} metric: {score[0]:.4f}')
            st.markdown(f'[Ссылка]({data.iloc[number, 0]})')
        with col3:
            tab1, tab2 = st.tabs(["Аннотация", "Описание от бота"])
            with tab1:
                st.text_area(label='Аннотация', value=data.iloc[number, 4],
                             height=250, disabled=True, label_visibility='hidden')
            with tab2:
                setting = "ты умеешь кратко в несколько предложений описывать содержание книги по ее названию"
                system_message_prompt = SystemMessagePromptTemplate.from_template(setting)
                human_template = "Кратко опиши cюжет сериала под названием: {title}"
                human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
                chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
                formatted_prompt = chat_prompt.format_prompt(title=data.iloc[number, 2])
                response = chat(formatted_prompt.to_messages())
                st.text_area(label='Чат', value=response.content,
                             height=250, disabled=True, label_visibility='hidden')


st.title('Рекомендатор сериалов')
st.divider()

data = load_data()

cols = st.columns(4)
with cols[0]:
    st.markdown('<p style="text-align: center; font-weight: bold;">Количество сериалов</p>', unsafe_allow_html=True)
    st.markdown('<p style="text-align: center;">5000</p>', unsafe_allow_html=True)

with cols[1]:
    st.markdown('<p style="text-align: center; font-weight: bold;">Источник парсинга</p>', unsafe_allow_html=True)
    st.markdown('<p style="text-align: center;"><a href="https://www.film.ru/a-z/serials" target="_blank">Перейти на сайт</a></p>', unsafe_allow_html=True)

with cols[2]:
    st.markdown('<p style="text-align: center; font-weight: bold;">Время парсинга</p>', unsafe_allow_html=True)
    st.markdown('<p style="text-align: center;">27 минут</p>', unsafe_allow_html=True)

with cols[3]:
    st.markdown('<p style="text-align: center; font-weight: bold;">Модель</p>', unsafe_allow_html=True)
    st.markdown('<p style="text-align: center;"><a href="https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2" target="_blank">multilingual-mpnet-base-v2</a></p>', unsafe_allow_html=True)

if st.button('Дай 10 случайных сериалов', use_container_width=True):
    st.session_state.button_clicked = True

if st.session_state.button_clicked:
    indices = random.sample(range(data.shape[0]), 10)
    for number in indices:
        show_serial(number)