Spaces:
Runtime error
| from transformers import ( | |
| AutoModelForImageClassification, | |
| AutoImageProcessor, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| from datasets import load_dataset | |
| import os | |
def train(
    model_name="SupremoUGH/image-classification-model",
    output_dir="./results",
    save_dir="./saved_model",
    num_train_epochs=3,
):
    """Fine-tune an image-classification checkpoint on MNIST and save it.

    The MNIST "train" split is used for training and the "test" split for
    periodic evaluation. After training, both the model and its image
    processor are written to ``save_dir``, so the directory can be reloaded
    with ``AutoModelForImageClassification.from_pretrained(save_dir)`` and
    ``AutoImageProcessor.from_pretrained(save_dir)``.

    Args:
        model_name: Hub id or local path of the checkpoint (and its image
            processor) to fine-tune.
        output_dir: Directory for Trainer checkpoints and logs.
        save_dir: Directory the final model and processor are saved to.
        num_train_epochs: Number of passes over the training split.
    """
    # Load dataset; the code below uses its "train"/"test" splits and its
    # "image"/"label" columns.
    dataset = load_dataset("ylecun/mnist")

    # Derive the label mapping from the dataset itself so the classification
    # head matches MNIST's classes, whatever head the base checkpoint has.
    label_names = dataset["train"].features["label"].names
    id2label = dict(enumerate(label_names))
    label2id = {name: idx for idx, name in id2label.items()}

    processor = AutoImageProcessor.from_pretrained(model_name)

    def process(examples):
        """Turn a batch of raw examples into model inputs with labels."""
        # MNIST digits are grayscale; convert to RGB because the processor /
        # model presumably expect 3-channel input (as the original code did).
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["label"]
        return inputs

    # Apply preprocessing lazily, at access time, rather than materializing
    # pixel tensors for the whole dataset up front.
    dataset.set_transform(process)

    model = AutoModelForImageClassification.from_pretrained(
        model_name,
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
        # If the checkpoint's head was trained for a different number of
        # classes, re-initialize it instead of failing with a size mismatch.
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        # Keep the raw "image" column: the Trainer would otherwise drop it
        # (the model's forward() takes no "image" argument) before the
        # lazy transform above gets to read it.
        remove_unused_columns=False,
        per_device_train_batch_size=16,  # kept small to limit memory use
        eval_strategy="steps",
        num_train_epochs=num_train_epochs,
        fp16=False,  # train in full precision
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )
    trainer.train()

    # Save the fine-tuned model together with its image processor so the
    # directory is self-contained for later from_pretrained() loading.
    os.makedirs(save_dir, exist_ok=True)
    trainer.save_model(save_dir)
    processor.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")
if __name__ == "__main__":
    # Script entry point: run the fine-tuning pipeline only when this file
    # is executed directly, not when it is imported.
    train()