Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| from torch import nn | |
| # Загрузка модели и токенизатора (кешируем для ускорения) | |
| def load_model(): | |
| MODEL_NAME = "cointegrated/rubert-tiny2" | |
| model = AutoModel.from_pretrained(MODEL_NAME, num_labels=5) | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| return model, tokenizer | |
| PATH = "models/model_weight_bert.pt" | |
| class MyTinyBERT(nn.Module): | |
| def __init__(self, model): | |
| super().__init__() | |
| self.bert = model | |
| for param in self.bert.parameters(): | |
| param.requires_grad = False | |
| self.linear = nn.Sequential( | |
| nn.Linear(312, 256), nn.Dropout(0.3), nn.ReLU(), nn.Linear(256, 5) | |
| ) | |
| def forward(self, input_ids, attention_mask): | |
| bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
| normed_bert_out = bert_out.last_hidden_state[:, 0, :] | |
| out = self.linear(normed_bert_out) | |
| return out | |
| def classification_myBERT(text, model, tokenizer): | |
| model = MyTinyBERT(model) | |
| model.load_state_dict(torch.load(PATH, weights_only=True)) | |
| model.eval() | |
| my_classes = {0: "Крипта", 1: "Мода", 2: "Спорт", 3: "Технологии", 4: "Финансы"} | |
| t = tokenizer(text, padding=True, truncation=True, return_tensors="pt") | |
| return f'Хоть я и не ChatGPT, осмелюсь предположить, что данный текст относится к следующему классу:\n{my_classes[torch.argmax(model(t["input_ids"], t["attention_mask"])).item()]}' | |
| # Интерфейс Streamlit | |
| def main(): | |
| st.markdown( | |
| "<h1 style='text-align: center;'>Классификация тематики новостей из телеграм каналов.</h1>", | |
| unsafe_allow_html=True, | |
| ) | |
| st.markdown("---") | |
| col1, col2, col3 = st.columns([1, 8, 1]) # Центральная колонка шире остальных | |
| with col2: | |
| st.markdown( | |
| "<h5 style='text-align: center;'>Использование классического алгоритма</h5>", | |
| unsafe_allow_html=True, | |
| ) | |
| # st.text("Использование классического алгоритма") | |
| st.image("./images/Struct.png", width=500) | |
| st.image("./images/L_A.png", width=800) | |
| st.image("./images/C_M.png", width=800) | |
| st.markdown( | |
| "<h5 style='text-align: center;'>Стандартный rubert_tiny2</h5>", | |
| unsafe_allow_html=True, | |
| ) | |
| # st.text("Использование классического алгоритма") | |
| st.image("./images/LogReg.png", width=800) | |
| st.markdown( | |
| "<h5 style='text-align: center;'>rubert_tiny2 с обучаемым fc слоем</h5>", | |
| unsafe_allow_html=True, | |
| ) | |
| # st.text("Использование классического алгоритма") | |
| st.image("./images/myTinyBERT.png", width=800) | |
| # Загрузка модели | |
| model, tokenizer = load_model() | |
| # Параметры генерации | |
| with st.sidebar: | |
| st.header("Настройки генерации") | |
| prompt = st.text_area("Введите начальный текст:", height=100) | |
| # Кнопка генерации | |
| if st.sidebar.button("Сгенерировать текст"): | |
| if not prompt: | |
| st.warning("Введите начальный текст!") | |
| return | |
| st.subheader("Результаты:") | |
| st.text(classification_myBERT(prompt, model, tokenizer)) | |
| if __name__ == "__main__": | |
| main() | |