Spaces:
Sleeping
Sleeping
| import time | |
| import joblib | |
| import re | |
| import string | |
| import pymorphy3 | |
| import torch | |
| from transformers import BertModel, BertTokenizer | |
| from torch import nn | |
# Download/load the pretrained rubert-tiny2 tokenizer and encoder once at import time.
model_name = "cointegrated/rubert-tiny2"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)
class MyTinyBERT(nn.Module):
    """Frozen rubert-tiny2 encoder with a small trainable classification head."""

    def __init__(self):
        super().__init__()
        # Reuse the module-level pretrained encoder and freeze all its weights;
        # only the head below is trained.
        self.bert = bert_model
        for weight in self.bert.parameters():
            weight.requires_grad = False
        # Classification head: 312-d [CLS] embedding -> 6 logits.
        self.linear = nn.Sequential(
            nn.Linear(312, 256),
            nn.Sigmoid(),
            nn.Dropout(),
            nn.Linear(256, 6),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return raw class logits for a batch of tokenized inputs."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # L2-normalize the [CLS] token embedding before the head.
        cls_embedding = nn.functional.normalize(encoded.last_hidden_state[:, 0, :])
        return self.linear(cls_embedding)
# Restore the trained classifier weights and pin inference to the CPU.
weights_path = "models/clf_rewievs_bert.pt"
model = MyTinyBERT()
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to('cpu')
| # tokenizer = transformers.AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") | |
| # bert_model = transformers.AutoModel.from_pretrained("cointegrated/rubert-tiny2") | |
| # weights_path = "./model_weights.pt" # Replace with your .pt file path | |
| # bert_model.load_state_dict(torch.load('models/clf_rewievs_bert.pt', map_location=torch.device('cpu'))) | |
| # bert_model.to('cpu') | |
# Shared morphological analyzer; parsing is the expensive part, so build it once.
morph = pymorphy3.MorphAnalyzer()


def lemmatize(text):
    """Reduce every whitespace-separated word of *text* to its normal form."""
    return " ".join(morph.parse(token)[0].normal_form for token in text.split())
# Classic TF-IDF + logistic-regression pipeline trained on restaurant reviews.
logreg = joblib.load('models/logregmodel_restaurants.pkl')
vectorizer = joblib.load('models/tfidf_vectorizer_restaurants.pkl')

# Russian stop-word list, one token per whitespace-separated entry.
with open("funcs/stopwords-ru.txt", "r", encoding="utf-8") as file:
    stop_words = set(file.read().split())
# Human-readable rating labels keyed by the classifier's 1-5 class labels.
rating_dict = {
    1: "Отвратительно",
    2: "Плохо",
    3: "Удовлетворительно",
    4: "Хорошо",
    5: "Великолепно",
}
# Matches the common emoji / pictograph Unicode blocks so they can be stripped.
emoji_pattern = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # Emoticons
    "\U0001F300-\U0001F5FF"  # Symbols & Pictographs
    "\U0001F680-\U0001F6FF"  # Transport & Map Symbols
    "\U0001F1E0-\U0001F1FF"  # Flags (iOS)
    "\U00002700-\U000027BF"  # Dingbats
    "\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
    "\U00002600-\U000026FF"  # Miscellaneous Symbols
    "\U00002B50-\U00002B55"  # Miscellaneous Symbols and Pictographs
    "\U0001FA70-\U0001FAFF"  # Symbols and Pictographs Extended-A
    "\U0001F700-\U0001F77F"  # Alchemical Symbols
    "\U0001F780-\U0001F7FF"  # Geometric Shapes Extended
    "\U0001F800-\U0001F8FF"  # Supplemental Arrows-C
    "\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
    "\U0001FA00-\U0001FA6F"  # Chess Symbols
    "]+",
    flags=re.UNICODE,
)


def clean(text, stopwords):
    """Normalize a raw Russian review for the TF-IDF pipeline.

    Steps: lowercase; strip URLs, @mentions, #hashtags, digits, ASCII
    punctuation, HTML tags and typographic leftovers; keep only Cyrillic
    letters and spaces; drop emojis; remove words found in *stopwords*.

    Returns the cleaned text as a single space-joined string.
    """
    text = text.lower()
    text = re.sub(r"http\S+", " ", text)  # remove links
    text = re.sub(r"@\w+", " ", text)  # remove user mentions
    text = re.sub(r"#\w+", " ", text)  # remove hashtags
    text = re.sub(r"\d+", " ", text)  # remove digits
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"<.*?>", " ", text)  # remove HTML tags
    text = re.sub(r"[️«»—]", " ", text)  # typographic quotes/dashes
    text = re.sub(r"[^а-яё ]", " ", text)  # keep only Cyrillic letters and spaces
    # Removed a redundant second text.lower() here: the text was lowercased on
    # entry and no later step can introduce uppercase characters.
    # NOTE(review): emoji removal is mostly redundant after the Cyrillic-only
    # filter above; kept as a safety net to preserve existing behavior.
    text = emoji_pattern.sub(r"", text)
    text = " ".join([word for word in text.split() if word not in stopwords])
    return text
def predict_review(review):
    """Classify a review with the TF-IDF + logistic-regression pipeline.

    Prints the label, rating and elapsed time, and returns the tuple
    (label, human-readable rating, elapsed seconds).
    """
    started = time.time()
    # Normalize the text the same way the training data was prepared.
    lemmas = lemmatize(clean(review, stop_words))
    # Vectorize the single sample and predict its class label.
    features = vectorizer.transform([lemmas])
    prediction = logreg.predict(features)[0]
    # Unexpected labels fall back to an explicit error string.
    rating = rating_dict.get(prediction, "Ошибка предсказания")
    elapsed_time = time.time() - started
    print(f"Лейбл: {prediction}")
    print(f"Оценка отзыва: {rating}")
    print(f"Затраченное время: {elapsed_time:.6f} seconds")
    return prediction, rating, elapsed_time
def preprocess_input(text):
    """Tokenize *text* into padded/truncated PyTorch tensors for the BERT model."""
    return tokenizer(
        text, return_tensors='pt', max_length=512, truncation=True, padding=True
    )
def predict_bert(text):
    """Classify a review with the fine-tuned rubert-tiny2 classifier.

    Returns the tuple (predicted class index, human-readable rating,
    elapsed seconds).
    """
    start_time = time.time()
    model.eval()
    inputs = preprocess_input(text)
    # Keep all tensors on the CPU, matching where the model was loaded.
    inputs = {k: v.to('cpu') for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']
        )
    # The model returns raw logits directly (not a HF output object).
    predicted_class = outputs.argmax(dim=-1).item()
    end_time = time.time()
    elapsed_time = end_time - start_time
    # BUG FIX: the head emits 6 logits (classes 0-5) but rating_dict only maps
    # 1-5, so class 0 previously raised KeyError. Use the same fallback as
    # predict_review.
    # NOTE(review): confirm whether class 0 should map to a real label or the
    # labels need an offset of +1 — cannot tell from this file alone.
    rating = rating_dict.get(predicted_class, "Ошибка предсказания")
    return predicted_class, rating, elapsed_time