File size: 2,454 Bytes
a2e27b6
 
e211b11
 
a2e27b6
e211b11
 
a2e27b6
e211b11
a2e27b6
e211b11
 
a2e27b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e211b11
a2e27b6
 
 
 
 
 
e211b11
 
a2e27b6
e211b11
a2e27b6
 
 
 
e211b11
 
a2e27b6
e211b11
 
a2e27b6
e211b11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from transformers import AutoModelForImageClassification, AutoProcessor, Trainer, TrainingArguments
from datasets import load_dataset

# 1. Carregar o modelo e o processador pré-treinado
model_name = "google/vit-base-patch16-224-in21k"
model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=3)
processor = AutoProcessor.from_pretrained(model_name)

# 2. Carregar o dataset
dataset = load_dataset("beans")

# 3. Função de pré-processamento das imagens
def preprocess_data(example):
    # Processar apenas a coluna "image"
    image = example['image']
    
    # O processor transforma a imagem em um tensor que o modelo pode entender
    inputs = processor(images=image, return_tensors="pt")
    
    # O Trainer espera tensores puros, então convertemos
    pixel_values = inputs["pixel_values"].squeeze()  # Remove dimensões extras
    labels = torch.tensor(example["labels"], dtype=torch.long)  # Converte labels para tensor

    return {"pixel_values": pixel_values, "labels": labels}

# 4. Aplicar o pré-processamento às imagens do dataset
train_dataset = dataset["train"].map(preprocess_data, remove_columns=["image"])
eval_dataset = dataset["test"].map(preprocess_data, remove_columns=["image"])

# **Corrigir o formato do dataset** - Definir os formatos corretamente
train_dataset.set_format(type="torch", columns=["pixel_values", "labels"])
eval_dataset.set_format(type="torch", columns=["pixel_values", "labels"])

# 5. Configurar os parâmetros de treinamento
training_args = TrainingArguments(
    output_dir="./vit-finetuned",         # Diretório para salvar o modelo treinado
    num_train_epochs=3,                   # Número de épocas para treinamento
    per_device_train_batch_size=8,        # Tamanho do batch de treinamento
    evaluation_strategy="epoch",          # Avaliar o modelo a cada época
    save_strategy="epoch",                # Salvar o modelo a cada época
    save_total_limit=2                    # Limitar o número de checkpoints salvos
)

# 6. Configurar o Trainer
trainer = Trainer(
    model=model,                          # O modelo treinado
    args=training_args,                   # Argumentos de treinamento
    train_dataset=train_dataset,          # Dataset de treinamento
    eval_dataset=eval_dataset             # Dataset de avaliação
)

# 7. Iniciar o treinamento
trainer.train()

# 8. Salvar o modelo finetunado
trainer.save_model("./vit-finetuned")