Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoModelForImageClassification, AutoProcessor, Trainer, TrainingArguments | |
| from datasets import load_dataset | |
| # 1. Carregar o modelo e o processador pré-treinado | |
| model_name = "google/vit-base-patch16-224-in21k" | |
| model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=3) | |
| processor = AutoProcessor.from_pretrained(model_name) | |
| # 2. Carregar o dataset | |
| dataset = load_dataset("beans") | |
| # 3. Função de pré-processamento das imagens | |
| def preprocess_data(example): | |
| # Processar apenas a coluna "image" | |
| image = example['image'] | |
| # O processor transforma a imagem em um tensor que o modelo pode entender | |
| inputs = processor(images=image, return_tensors="pt") | |
| # O Trainer espera tensores puros, então convertemos | |
| pixel_values = inputs["pixel_values"].squeeze() # Remove dimensões extras | |
| labels = torch.tensor(example["labels"], dtype=torch.long) # Converte labels para tensor | |
| return {"pixel_values": pixel_values, "labels": labels} | |
| # 4. Aplicar o pré-processamento às imagens do dataset | |
| train_dataset = dataset["train"].map(preprocess_data, remove_columns=["image"]) | |
| eval_dataset = dataset["test"].map(preprocess_data, remove_columns=["image"]) | |
| # **Corrigir o formato do dataset** - Definir os formatos corretamente | |
| train_dataset.set_format(type="torch", columns=["pixel_values", "labels"]) | |
| eval_dataset.set_format(type="torch", columns=["pixel_values", "labels"]) | |
| # 5. Configurar os parâmetros de treinamento | |
| training_args = TrainingArguments( | |
| output_dir="./vit-finetuned", # Diretório para salvar o modelo treinado | |
| num_train_epochs=3, # Número de épocas para treinamento | |
| per_device_train_batch_size=8, # Tamanho do batch de treinamento | |
| evaluation_strategy="epoch", # Avaliar o modelo a cada época | |
| save_strategy="epoch", # Salvar o modelo a cada época | |
| save_total_limit=2 # Limitar o número de checkpoints salvos | |
| ) | |
| # 6. Configurar o Trainer | |
| trainer = Trainer( | |
| model=model, # O modelo treinado | |
| args=training_args, # Argumentos de treinamento | |
| train_dataset=train_dataset, # Dataset de treinamento | |
| eval_dataset=eval_dataset # Dataset de avaliação | |
| ) | |
| # 7. Iniciar o treinamento | |
| trainer.train() | |
| # 8. Salvar o modelo finetunado | |
| trainer.save_model("./vit-finetuned") | |