|
|
|
|
|
"""NLP_GENERAL.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1g7CiQ8eJjVdDnZMoBWSOD01rHMVuQdC3 |
|
|
|
|
|
# Классификация |
|
|
|
|
|
## Библиотеки и зависимости |
|
|
""" |
|
|
|
|
|
!pip install pymorphy2 |
|
|
!pip install ufal.udpipe |
|
|
!pip install wget |
|
|
!pip install gensim |
|
|
!pip install umap-learn |
|
|
!pip install datashader |
|
|
!pip install bokeh |
|
|
!pip install holoviews |
|
|
!pip install yargy |
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
import seaborn as sns |
|
|
import pymorphy2 as mph |
|
|
import re |
|
|
import wget |
|
|
import sys |
|
|
from gensim.models import Word2Vec as w2v |
|
|
import logging |
|
|
import string |
|
|
import nltk |
|
|
from nltk import word_tokenize |
|
|
from nltk.corpus import stopwords |
|
|
import random |
|
|
import json |
|
|
import numpy as np |
|
|
import umap |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
from yargy import Parser, rule, and_, or_ |
|
|
from yargy.interpretation import fact, attribute |
|
|
from yargy.predicates import normalized, dictionary |
|
|
from yargy.pipelines import morph_pipeline |
|
|
from yargy.relations import main |
|
|
from IPython.display import display |
|
|
import spacy |
|
|
|
|
|
nltk.download('punkt') |
|
|
nltk.download('stopwords') |
|
|
sw = stopwords.words('russian') |
|
|
|
|
|
"""## Предобработка |
|
|
|
|
|
## 1. Предобработка текста |
|
|
|
|
|
|
|
|
* 1. ([Kaggle](https://www.kaggle.com/code/sudalairajkumar/getting-started-with-text-preprocessing)). |
|
|
* 2. (https://www.kaggle.com/code/abdmental01/text-preprocessing-nlp-steps-to-process-text)). |
|
|
* 3. (https://neptune.ai/blog/text-classification-tips-and-tricks-kaggle-competitions) |
|
|
|
|
|
Лемматизация |
|
|
|
|
|
|
|
|
--- |
|
|
""" |
|
|
|
|
|
patterns = "[A-Za-z0-9!#$%&'()*+/:;<=>?@[\]^_`{|}~—\"]+" |
|
|
morph = mph.MorphAnalyzer() |
|
|
|
|
|
def lemmatize(doc): |
|
|
doc = re.sub(patterns, ' ', doc) |
|
|
tokens = [] |
|
|
for token in doc.split(): |
|
|
if token: |
|
|
token = token.strip() |
|
|
token = morph.normal_forms(token)[0] |
|
|
tokens.append(token) |
|
|
return ' '.join(tokens) |
|
|
|
|
|
"""Наташа |
|
|
|
|
|
|
|
|
--- |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
topic_name = [] |
|
|
topic_one_to_one = [] |
|
|
Case = fact('Case', ['name']) |
|
|
|
|
|
def make_topic(topic: list, name: str): |
|
|
global topic_name |
|
|
|
|
|
topic_name.append(morph_pipeline(topic).interpretation( |
|
|
Case.name.const(name) |
|
|
).interpretation( |
|
|
Case |
|
|
) |
|
|
) |
|
|
|
|
|
def make_topic_one_to_one(topic: list): |
|
|
global topic_name |
|
|
|
|
|
return morph_pipeline(topic).interpretation( |
|
|
Case.name.normalized() |
|
|
).interpretation( |
|
|
Case |
|
|
) |
|
|
|
|
|
top_topic = [ |
|
|
(["окружность", "угол"], 'Геометрия'), |
|
|
|
|
|
(["деление", "множители"], 'Многочлен'), |
|
|
|
|
|
(["клетка", "закрасить"], 'Дирихле'), |
|
|
|
|
|
(["делится", "оканчивается"], 'Теория чисел'), |
|
|
|
|
|
(["способ", "разделить"], 'Комбинаторика'), |
|
|
|
|
|
(["последовательность", "разрешаться"], 'Инвариант'), |
|
|
|
|
|
(["сумма", "каждый", ], 'Оценка+Пример'), |
|
|
|
|
|
(['город', "ребро",], 'Графы') |
|
|
] |
|
|
|
|
|
for name_complaint in top_topic: |
|
|
make_topic(name_complaint[0], name_complaint[1]) |
|
|
topic_one_to_one.extend(list(name_complaint[0])) |
|
|
for columns in list(name_complaint[0]): |
|
|
data[columns] = np.NaN |
|
|
|
|
|
OTHERS = make_topic_one_to_one(topic_one_to_one) |
|
|
|
|
|
ALL = or_(*topic_name).interpretation(Case) |
|
|
OTHERS_ALL = or_(OTHERS).interpretation(Case) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Стоп слова""" |
|
|
|
|
|
|
|
|
|
|
|
def remove_stopwords(lines, sw=sw): |
|
|
res = [] |
|
|
for line in lines: |
|
|
original = line |
|
|
line = [w for w in line if w not in sw] |
|
|
if len(line) < 1: |
|
|
line = original |
|
|
res.append(line) |
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
"""Word2Vec""" |
|
|
|
|
|
|
|
|
|
|
|
random.shuffle(filtered_lines) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.save("word2vec.model") |
|
|
|
|
|
|
|
|
model = w2v.load("/content/drive/MyDrive/Проекты/Medsi/Models/word2vec.model") |
|
|
|
|
|
|
|
|
merge_data_filter_2.illness_hostory = merge_data_filter_2.illness_hostory.apply(lemmatize) |
|
|
|
|
|
|
|
|
for i in range(100): |
|
|
merge_data_filter_2[f'vector_{i}'] = 0 |
|
|
|
|
|
for j, text in enumerate(merge_data_filter_2['illness_hostory']): |
|
|
vec = np.zeros(100) |
|
|
lens = 0 |
|
|
for word in word_tokenize(text): |
|
|
try: |
|
|
vec += model.wv[word] |
|
|
lens += 1 |
|
|
except KeyError: |
|
|
continue |
|
|
|
|
|
vec /= lens |
|
|
for i in range(100): |
|
|
merge_data_filter_2.iloc[j, 103+i] = vec[i] |
|
|
|
|
|
"""Umap""" |
|
|
|
|
|
import umap.plot |
|
|
|
|
|
mapper = umap.UMAP(densmap=True).fit(X) |
|
|
umap.plot.points(mapper) |
|
|
|
|
|
"""Фильтрация пунктуации""" |
|
|
|
|
|
def remove_punctuation(text): |
|
|
translator = str.maketrans('', '', string.punctuation) |
|
|
return text.translate(translator) |
|
|
|
|
|
"""Облако слов""" |
|
|
|
|
|
from wordcloud import WordCloud |
|
|
|
|
|
for topic in data.topic.unique(): |
|
|
df = data[data.topic == topic] |
|
|
text = ' '.join(df['new_task']) |
|
|
text_tokens = word_tokenize(text) |
|
|
|
|
|
cloud = WordCloud(stopwords=stop_words, |
|
|
background_color='white').generate(' '.join(text_tokens)) |
|
|
plt.imshow(cloud) |
|
|
plt.axis('off') |
|
|
plt.title(topic) |
|
|
plt.show() |
|
|
|
|
|
"""N-граммы""" |
|
|
|
|
|
k = 30 |
|
|
n = 2 |
|
|
for topic in data.topic.unique(): |
|
|
df = data[data.topic == topic] |
|
|
words = ' '.join(df.new_task_pros) |
|
|
words = ' '.join(list(filter(lambda x: len(x) >= 2, (words.split())))) |
|
|
tokens = nltk.word_tokenize(words) |
|
|
|
|
|
ngrams_list = list(ngrams(tokens, n)) |
|
|
freq_dist = dict(FreqDist(ngrams_list)) |
|
|
sorted_data = sorted(freq_dist.items(), key=lambda x: -x[1]) |
|
|
|
|
|
y_labels = [str(key) for key, _ in sorted_data][:k][::-1] |
|
|
x_values = [value for _, value in sorted_data][:k][::-1] |
|
|
|
|
|
plt.barh(y_labels, x_values) |
|
|
plt.xlabel('Значение') |
|
|
plt.ylabel('Кортежи') |
|
|
plt.title(topic) |
|
|
plt.show() |
|
|
|
|
|
"""TF-IDF""" |
|
|
|
|
|
def vect_tfidf(text): |
|
|
return vectorizer.transform([text]).toarray() |
|
|
|
|
|
vectorizer = TfidfVectorizer(max_features=5000, min_df=3) |
|
|
X = vectorizer.fit_transform(learn_tf_idf) |
|
|
|
|
|
"""Tenserflow token""" |
|
|
|
|
|
vocab_size = 20000 |
|
|
trunc_type = 'post' |
|
|
padding_type = 'post' |
|
|
embedding_dim = 128 |
|
|
max_length = 120 |
|
|
oov_tok = '' |
|
|
|
|
|
text = data['new_task'] |
|
|
labels = data['y'] |
|
|
tokenizer = Tokenizer( |
|
|
num_words=vocab_size, |
|
|
filters='!"#$%&()*+,-./:;<=>?@[\]^_`{|}~\t\n', |
|
|
lower=True, |
|
|
oov_token=oov_tok |
|
|
) |
|
|
|
|
|
tokenizer.fit_on_texts(text) |
|
|
train_sequences = tokenizer.texts_to_sequences(text) |
|
|
train_padded = pad_sequences( |
|
|
train_sequences, |
|
|
maxlen=max_length, |
|
|
padding=padding_type, |
|
|
truncating=trunc_type |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_sequences = tokenizer.texts_to_sequences(data.new_task) |
|
|
train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type) |
|
|
|
|
|
|
|
|
for i in tqdm(range(max_length)): |
|
|
data[f"Tokens f.{i + 1}"] = train_padded[:, i] |
|
|
|
|
|
"""## Finetune Bert""" |
|
|
|
|
|
!pip install transformers |
|
|
!pip install accelerate -U |
|
|
|
|
|
import torch |
|
|
import pandas as pd |
|
|
from transformers import AutoModelForSequenceClassification |
|
|
from transformers import BertTokenizerFast |
|
|
from transformers import TrainingArguments |
|
|
import torch, os |
|
|
import pandas as pd |
|
|
from transformers import pipeline, BertForSequenceClassification, BertTokenizerFast |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
import os |
|
|
import re |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import warnings |
|
|
import numpy as np |
|
|
import evaluate |
|
|
|
|
|
metric = evaluate.load("f1") |
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
dataset = dataset[['task', 'topic']] |
|
|
dataset.rename(columns={'task': 'text', |
|
|
'topic': 'labels'}, |
|
|
inplace=True) |
|
|
NUM_LABELS = len(dataset.labels.unique()) |
|
|
|
|
|
id2label = {id: label for id, label in enumerate(dataset.labels.unique())} |
|
|
|
|
|
label2id = {label: id for id, label in enumerate(dataset.labels.unique())} |
|
|
|
|
|
|
|
|
tokenizer = BertTokenizerFast.from_pretrained('blanchefort/rubert-base-cased-sentiment') |
|
|
model = BertForSequenceClassification.from_pretrained('blanchefort/rubert-base-cased-sentiment', |
|
|
num_labels=NUM_LABELS, id2label=id2label, |
|
|
label2id=label2id, |
|
|
ignore_mismatched_sizes=True) |
|
|
|
|
|
train_encodings = tokenizer(list(X_train), truncation=True, padding=True) |
|
|
val_encodings = tokenizer(list(X_val), truncation=True, padding=True) |
|
|
test_encodings = tokenizer(list(X_test), truncation=True, padding=True) |
|
|
|
|
|
class DataLoader(Dataset): |
|
|
def __init__(self, encodings, labels): |
|
|
self.encodings = encodings |
|
|
self.labels = labels |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
|
|
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} |
|
|
|
|
|
item['labels'] = torch.tensor(self.labels[idx]) |
|
|
return item |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.labels) |
|
|
|
|
|
train_dataloader = DataLoader(train_encodings, list(y_train)) |
|
|
val_dataloader = DataLoader(val_encodings, list(y_val)) |
|
|
test_dataset = DataLoader(test_encodings, list(y_test)) |
|
|
|
|
|
trainer = Trainer( |
|
|
model=model, |
|
|
args=training_args, |
|
|
train_dataset=train_dataloader, |
|
|
eval_dataset=val_dataloader, |
|
|
compute_metrics=compute_metrics |
|
|
) |
|
|
|
|
|
|
|
|
trainer.train() |
|
|
|
|
|
def predict(text): |
|
|
inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to("cuda") |
|
|
|
|
|
outputs = model(**inputs) |
|
|
|
|
|
probs = outputs[0].softmax(1) |
|
|
pred_label_idx = probs.argmax() |
|
|
pred_label = model.config.id2label[pred_label_idx.item()] |
|
|
|
|
|
return probs, pred_label_idx, pred_label |
|
|
|
|
|
|
|
|
text = input() |
|
|
predict(text) |
|
|
|
|
|
"""## Text Classification: All Tips and Tricks from 5 Kaggle Competitions, |
|
|
|
|
|
1. Оптимизация памяти при работе с большими датасетами |
|
|
|
|
|
Использование Dask для чтения и обработки данных: https://dask.org/ |
|
|
|
|
|
Использование cuDF для ускоренной обработки данных на GPU: https://docs.rapids.ai/api/cudf/stable/ |
|
|
|
|
|
Конвертация данных в формат Parquet: https://parquet.apache.org/ |
|
|
|
|
|
Конвертация данных в формат Feather: https://arrow.apache.org/docs/python/feather.html |
|
|
|
|
|
2. Методы увеличения данных (Data Augmentation) |
|
|
|
|
|
Замена слов синонимами для увеличения данных: https://towardsdatascience.com/data-augmentation-in-nlp-2801a34dfc28 |
|
|
|
|
|
Добавление шума в тексты для обучения RNN: https://arxiv.org/abs/1703.02573 |
|
|
|
|
|
Перевод текста на другие языки и обратно для создания новых примеров: https://arxiv.org/abs/1511.06709 |
|
|
|
|
|
3. Исследование данных и получение инсайтов |
|
|
|
|
|
Простая разведывательная аналитика (EDA) для твитов: https://www.kaggle.com/code/ashishpatel26/simple-eda-for-tweets |
|
|
|
|
|
EDA для данных Quora: https://www.kaggle.com/code/sudalairajkumar/simple-eda-for-quora-question-pairs |
|
|
|
|
|
Полный EDA для данных Stack Exchange: https://www.kaggle.com/code/ashishpatel26/complete-eda-with-stack-exchange-data |
|
|
|
|
|
Предыдущая статья автора о EDA для обработки естественного языка: https://neptune.ai/blog/exploratory-data-analysis-nlp |
|
|
|
|
|
4. Очистка данных |
|
|
|
|
|
Использование TextBlob для исправления орфографических ошибок: https://textblob.readthedocs.io/en/dev/ |
|
|
|
|
|
Предобработка для GloVe (часть 1): https://www.kaggle.com/code/ashishpatel26/preprocessing-for-glove-part-1 |
|
|
|
|
|
Предобработка для GloVe (часть 2): https://www.kaggle.com/code/ashishpatel26/preprocessing-for-glove-part-2 |
|
|
|
|
|
5. Представление текста |
|
|
|
|
|
Комбинирование предварительно обученных векторов для лучшего представления текста и уменьшения количества неизвестных слов: https://www.kaggle.com/code/ashishpatel26/combining-pre-trained-vectors |
|
|
|
|
|
Использование Universal Sentence Encoder для генерации признаков на уровне предложений: https://tfhub.dev/google/universal-sentence-encoder/4 |
|
|
|
|
|
Три метода комбинирования эмбеддингов: https://www.kaggle.com/code/ashishpatel26/3-methods-to-combine-embeddings |
|
|
|
|
|
6. Архитектура модели |
|
|
|
|
|
Стекирование двух слоев LSTM/GRU для улучшения производительности: https://www.kaggle.com/code/ashishpatel26/stacking-2-layers-of-lstm-gru-networks |
|
|
|
|
|
7. Функции потерь |
|
|
|
|
|
Использование фокальной функции потерь для несбалансированных данных: https://arxiv.org/abs/1708.02002 |
|
|
|
|
|
Пользовательская функция потерь "mimic loss", использованная в соревновании Jigsaw: https://www.kaggle.com/code/ashishpatel26/custom-mimic-loss-jigsaw |
|
|
|
|
|
Пользовательская функция потерь MTL, использованная в соревновании Jigsaw: https://www.kaggle.com/code/ashishpatel26/mtl-custom-loss-jigsaw |
|
|
|
|
|
8. Оптимизаторы |
|
|
|
|
|
Использование Adam с прогревом (warmup): https://www.kaggle.com/code/ashishpatel26/adam-with-warmup |
|
|
|
|
|
Использование BertAdam для моделей на основе BERT: https://www.kaggle.com/code/ashishpatel26/bert-adam |
|
|
|
|
|
Использование Rectified Adam для стабилизации обучения и ускорения сходимости: https://arxiv.org/abs/1908.03265 |
|
|
|
|
|
9. Методы обратного вызова (Callbacks) |
|
|
|
|
|
Контрольная точка модели для мониторинга и сохранения весов: https://www.kaggle.com/code/ashishpatel26/model-checkpoint |
|
|
|
|
|
Планировщик скорости обучения для изменения скорости обучения на основе производительности модели: https://www.kaggle.com/code/ashishpatel26/learning-rate-scheduler |
|
|
|
|
|
Простые пользовательские обратные вызовы с использованием lambda-функций: https://www.kaggle.com/code/ashishpatel26/simple-custom-callbacks |
|
|
|
|
|
Пользовательская контрольная точка: https://www.kaggle.com/code/ashishpatel26/custom-checkpointing |
|
|
|
|
|
Создание собственных обратных вызовов для различных случаев использования: https://www.kaggle.com/code/ashishpatel26/building-custom-callbacks |
|
|
|
|
|
Уменьшение на плато для снижения скорости обучения, когда метрика перестает улучшаться: https://www.kaggle.com/code/ashishpatel26/reduce-on-plateau |
|
|
|
|
|
Раннее прекращение обучения при отсутствии улучшений: https://www.kaggle.com/code/ashishpatel26/early-stopping |
|
|
|
|
|
Снимок ансамблирования для получения различных контрольных точек модели в одном обучении: https://www.kaggle.com/code/ashishpatel26/snapshot-ensembling |
|
|
|
|
|
Быстрое геометрическое ансамблирование: https://www.kaggle.com/code/ashishpatel26/fast-geometric-ensembling |
|
|
|
|
|
Стохастическое усреднение весов (SWA): https://www.kaggle.com/code/ashishpatel26/stochastic-weight-averaging |
|
|
|
|
|
Динамическое уменьшение скорости обучения: https://www.kaggle.com/code/ashishpatel26/dynamic-learning-rate-decay |
|
|
|
|
|
10. Оценка и кросс-валидация |
|
|
|
|
|
K-кратная кросс-валидация: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html |
|
|
|
|
|
Стратифицированная K-кратная кросс-валидация: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html |
|
|
|
|
|
Групповая K-кратная кросс-валидация: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GroupKFold.html |
|
|
|
|
|
Адвенсариальная валидация для проверки сходства распределений обучающего и тестового наборов: https://www.kaggle.com/code/ashishpatel26/adversarial-validation |
|
|
|
|
|
Анализ различных стратегий кросс-валидации: https://www.kaggle.com/code/ashishpatel26/cv-analysis-different-strategies |
|
|
|
|
|
11. Трюки для ускорения выполнения |
|
|
|
|
|
Сортировка последовательностей по длине для экономии времени выполнения и улучшения производительности: https://www.kaggle.com/code/ashishpatel26/sequence-bucketing |
|
|
|
|
|
Использование только начала и конца предложений, если длина превышает 512 токенов: https://www.kaggle.com/code/ashishpatel26/head-tail-trick |
|
|
|
|
|
Эффективное использование GPU: https://www.kaggle.com/code/ashishpatel26/use-gpu-efficiently |
|
|
|
|
|
Очистка памяти Keras: https://www.kaggle.com/code/ashishpatel26/free-keras-memory |
|
|
|
|
|
Сохранение и загрузка моделей для экономии времени и памяти: https://www.kaggle.com/code/ashishpatel26/save-load-models |
|
|
|
|
|
Не сохранять эмбеддинги в решениях на основе RNN: https://www.kaggle.com/code/ashishpatel26/dont-save-embedding-rnn |
|
|
|
|
|
Загрузка векторов word2vec без ключевых векторов: https://www.kaggle.com/code/ashishpatel26/load-word2vec-without-key-vectors |
|
|
|
|
|
12. Ансамблирование моделей |
|
|
|
|
|
Взвешенное среднее ансамблирование: https://www.kaggle.com/code/ashishpatel26/weighted-average-ensemble |
|
|
|
|
|
Стекированное обобщение (stacked generalization) ансамблирование: https://www.kaggle.com/code/ashishpatel26/stacked-generalization-ensemble |
|
|
|
|
|
Предсказания вне обучающего набора (out-of-fold predictions): https://www.kaggle.com/code/ashishpatel26/out-of-fold-predictions |
|
|
|
|
|
Смешивание с линейной регрессией: https://www.kaggle.com/code/ashishpatel26/blending-linear-regression |
|
|
|
|
|
Использование Optuna для определения весов смешивания: https://optuna.org/ |
|
|
|
|
|
Среднее по степени (power average) ансамблирование: https://www.kaggle.com/code/ashishpatel26/power-average-ensemble |
|
|
|
|
|
Стратегия смешивания с использованием степени 3.5: https://www.kaggle.com/code/ashishpatel26/power-3-5-blending-strategy |
|
|
|
|
|
# Генерация |
|
|
|
|
|
📌 Когда использовать что |
|
|
|
|
|
| Сценарий | Подход | |
|
|
| ---------------------------------------------------- | ---------------------------------------------- | |
|
|
| Маленькие датасеты, учебные задачи | RNN / LSTM | |
|
|
| Длинные последовательности, умеренные ресурсы | LSTM (для стабильности) или GRU (для скорости) | |
|
|
| Требуется копирование или внимание к части входа | RNN + Attention | |
|
|
| Лучшее качество, много данных и ресурсов | Полное дообучение трансформеров | |
|
|
| Большая модель, но мало памяти (например, 16 ГБ GPU) | LoRA / QLoRA | |
|
|
| Несколько задач на одной базе | Adapters или Prefix Tuning | |
|
|
| Небольшой датасет, few-shot или zero-shot | Prompt Tuning / Soft Prompts | |
|
|
|
|
|
https://www.kaggle.com/code/purvasingh/text-generation-via-rnn-and-lstms-pytorch |
|
|
|
|
|
https://www.kaggle.com/code/neerajmohan/finetuning-large-language-models-using-qlora |
|
|
|
|
|
https://www.kaggle.com/code/thebrownviking20/intro-to-recurrent-neural-networks-lstm-gru?utm_source=chatgpt.com |
|
|
""" |
|
|
|
|
|
from transformers import BertTokenizerFast, BertForSequenceClassification, Trainer, TrainingArguments |
|
|
from torch.utils.data import Dataset |
|
|
import torch |
|
|
import evaluate |
|
|
import warnings |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_args = TrainingArguments( |
|
|
output_dir="./results", |
|
|
num_train_epochs=3, |
|
|
per_device_train_batch_size=8, |
|
|
per_device_eval_batch_size=64, |
|
|
warmup_steps=500, |
|
|
weight_decay=0.01, |
|
|
logging_dir='./logs', |
|
|
logging_steps=10, |
|
|
evaluation_strategy="steps", |
|
|
eval_steps=500, |
|
|
save_steps=500, |
|
|
save_total_limit=2 |
|
|
) |
|
|
|
|
|
|
|
|
def compute_metrics(pred): |
|
|
labels = pred.label_ids |
|
|
preds = pred.predictions.argmax(-1) |
|
|
f1 = metric.compute(predictions=preds, references=labels, average="weighted") |
|
|
return { |
|
|
'f1': f1["f1"], |
|
|
} |
|
|
|
|
|
|