| import torch
|
| import evaluate
|
| import numpy as np
|
| from datasets import load_dataset
|
| from transformers import (
|
| AutoImageProcessor,
|
| AutoModelForImageClassification,
|
| TrainingArguments,
|
| Trainer,
|
| )
|
|
|
|
|
|
|
|
|
ds = load_dataset("Nagabu/HAM10000")

# Hold out 10% of the "train" split for evaluation.
# - seed: without it the split (and therefore the reported eval accuracy)
#   changes on every run, making results non-reproducible.
# - stratify_by_column: keeps the per-class proportions of the held-out set
#   aligned with the training set. "label" is used as a ClassLabel further
#   down (its `.names` are read), which is what stratification requires.
# NOTE(review): this held-out set is also the one used to pick the best
# checkpoint (load_best_model_at_end), so its accuracy is not an unbiased
# test estimate; consider a separate validation split.
# NOTE(review): if the dataset exposes a lesion identifier, splitting by it
# would keep images of the same lesion from landing on both sides — confirm
# against the dataset schema.
split = ds["train"].train_test_split(
    test_size=0.1,
    seed=42,
    stratify_by_column="label",
)

train_ds = split["train"]
test_ds = split["test"]
|
|
|
|
|
|
|
# Pretrained backbone and the image processor published alongside it.
model_name = "google/vit-base-patch16-224-in21k"
processor = AutoImageProcessor.from_pretrained(model_name)

# Class names come from the dataset's label feature; the index<->name maps are
# passed to the model so its config carries readable label names.
labels = train_ds.features["label"].names
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}
|
|
|
|
|
def transform(example_batch):
    """Turn a batch of raw dataset rows into model inputs.

    Registered with ``Dataset.set_transform``, so it runs on the fly whenever
    rows are read rather than being precomputed.

    Args:
        example_batch: mapping of column name -> list of values. Must contain
            ``"image"`` and ``"label"``. The images are assumed to be PIL images
            (what the ``datasets`` Image feature decodes to) — TODO confirm for
            this dataset.

    Returns:
        The image processor's output for the batch (PyTorch tensors), with the
        batch's ``"label"`` values attached under the ``"label"`` key.
    """
    # Normalise every image to 3-channel RGB before preprocessing. A grayscale,
    # RGBA or palette image in the batch could otherwise make the processor
    # infer the wrong channel layout or fail to stack the batch.
    images = [image.convert("RGB") for image in example_batch["image"]]
    inputs = processor(images, return_tensors="pt")
    inputs["label"] = example_batch["label"]
    return inputs
|
|
|
|
|
# Attach the preprocessing lazily to both splits: `transform` is applied on the
# fly each time rows are accessed instead of being computed up front.
for split_ds in (train_ds, test_ds):
    split_ds.set_transform(transform)
|
|
|
|
|
|
|
|
|
|
|
# Load the backbone with a classification head sized for this dataset's
# classes and labelled with the dataset's class names.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(labels),
    # Tolerate checkpoint weights whose shape does not match the new head;
    # mismatched layers are freshly initialised instead of raising.
    ignore_mismatched_sizes=True,
)
|
|
|
|
|
# Metric loaded once at module level and reused on every evaluation.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: a ``(logits, label_ids)`` pair. The logits are assumed to be
            shaped (num_examples, num_labels) — argmax is taken over axis 1.

    Returns:
        The dict produced by the ``accuracy`` metric (key ``"accuracy"``).
    """
    logits, references = eval_pred
    predicted_ids = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)
|
|
|
|
|
training_args = TrainingArguments(
    output_dir="./meu-modelo-ham10000",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    # Evaluate, checkpoint and log once per epoch. Evaluation and saving share
    # the same cadence so the best checkpoint can be restored at the end.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key returned by compute_metrics
    # Keep the raw "image" column: `transform` reads it, and Trainer would
    # otherwise drop columns that the model's forward() does not accept.
    remove_unused_columns=False,
)
|
|
|
|
|
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
    # Handing the processor to Trainer lets it be saved together with the model.
    # NOTE(review): recent transformers releases deprecate `tokenizer=` in
    # favour of `processing_class=`; switch once the minimum supported version
    # allows it.
    tokenizer=processor,
)
|
|
|
|
|
|
|
print("Iniciando o treinamento...")
trainer.train()

# With load_best_model_at_end=True in the training arguments, the model held by
# the trainer at this point is the best checkpoint by eval accuracy.
print("Treinamento concluído. Salvando o modelo final...")
final_dir_name = "meu-modelo-ham10000-final"
trainer.save_model("./" + final_dir_name)
print(f"Modelo salvo com sucesso na pasta '{final_dir_name}'")