Files changed (1) hide show
  1. README.md +43 -64
README.md CHANGED
@@ -1,65 +1,44 @@
1
- ---
2
- license: apache-2.0
3
- ---
4
- pip install transformers datasets torch scikit-learn
5
- import torch
6
  from datasets import load_dataset
7
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
8
- from sklearn.model_selection import train_test_split
9
- from sklearn.metrics import accuracy_score, f1_score
10
- def load_and_prepare_data():
11
- dataset = load_dataset("emotion")
12
- train_dataset = dataset["train"]
13
- test_dataset = dataset["test"]
14
- return train_dataset, test_dataset
15
- def tokenize_dataset(dataset):
16
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
17
- def tokenize_function(examples):
18
- return tokenizer(examples["text"], padding="max_length", truncation=True)
19
- tokenized_dataset = dataset.map(tokenize_function, batched=True)
20
- return tokenized_dataset
21
- def load_model():
22
- num_labels = 6
23
- model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=num_labels)
24
- return model
25
- def define_training_arguments():
26
- training_args = TrainingArguments(
27
- output_dir="./results",
28
- num_train_epochs=3,
29
- per_device_train_batch_size=16,
30
- per_device_eval_batch_size=64,
31
- warmup_steps=500,
32
- weight_decay=0.01,
33
- logging_dir="./logs",
34
- logging_steps=10,
35
- evaluation_strategy="epoch",
36
- save_strategy="epoch",
37
- load_best_model_at_end=True,
38
- metric_for_best_model="accuracy",
39
- greater_is_better=True,
40
- )
41
- return training_args
42
- def compute_metrics(eval_pred):
43
- logits, labels = eval_pred
44
- predictions = torch.argmax(torch.tensor(logits), dim=-1)
45
- accuracy = accuracy_score(labels, predictions)
46
- f1 = f1_score(labels, predictions, average="weighted")
47
- return {"accuracy": accuracy, "f1": f1}
48
- def main():
49
- train_dataset, test_dataset = load_and_prepare_data()
50
- tokenized_train_dataset = tokenize_dataset(train_dataset)
51
- tokenized_test_dataset = tokenize_dataset(test_dataset)
52
- model = load_model()
53
- training_args = define_training_arguments()
54
- trainer = Trainer(
55
- model=model,
56
- args=training_args,
57
- train_dataset=tokenized_train_dataset,
58
- eval_dataset=tokenized_test_dataset,
59
- compute_metrics=compute_metrics,
60
- )
61
- trainer.train()
62
- trainer.evaluate()
63
- trainer.save_model()
64
- if __name__ == "__main__":
65
- main()
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 
 
 
2
  from datasets import load_dataset
3
+ import torch
4
+ from sklearn.metrics import classification_report, confusion_matrix
5
+
6
+ # Загружаем модель и токенизатор
7
+ model_name = 'your_model_name'
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
+
11
+ # Загружаем датасет
12
+ dataset = load_dataset('mnli', split='validation_matched[:1%]')
13
+
14
+ # Токенизация
15
+ def tokenize_function(examples):
16
+ return tokenizer(examples["premise"], examples["hypothesis"], truncation=True)
17
+
18
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
19
+ labels = tokenized_dataset['label']
20
+
21
+ # Готовим батчи для предсказаний
22
+ inputs = tokenized_dataset.remove_columns(['premise', 'hypothesis'])
23
+ inputs.set_format(type="torch")
24
+ loader = torch.utils.data.DataLoader(inputs, batch_size=8)
25
+
26
+ # Используем GPU, если доступно
27
+ device = torch.device("cuda") if torch.cuda.isavailable() else torch.device("cpu")
28
+ model.to(device)
29
+
30
+ # Получаем предсказания
31
+ preds = []
32
+ for batch in loader:
33
+ outputs = model(**batch.to(device))
34
+ preds.extend(outputs.logits.argmax(dim=-1).tolist())
35
+
36
+ predicted_labels = preds
37
+
38
+ # Оцениваем производительность
39
+ report = classification_report(labels, predicted_labels)
40
+ matrix = confusion_matrix(labels, predicted_labels)
41
+
42
+ print(report)
43
+ print("\nМатрица путаницы:")
44
+ print(matrix)