| | |
| | """train_model.ipynb |
| | |
| | Automatically generated by Colab. |
| | |
| | Original file is located at |
| | https://colab.research.google.com/drive/1BMInZz4vjJ1PfgTbbqIknpJYcbM5cwV0 |
| | """ |
| |
|
| | import torch |
| | import numpy as np |
| | from datasets import load_dataset |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments |
| |
|
print("Downloading dataset...")
dataset = load_dataset("papluca/language-identification", split="train")

# Single source of truth for the label scheme: the set of kept languages and
# the inverse id -> language mapping are both derived from label2id, so the
# three structures cannot drift out of sync.
label2id = {"en": 0, "fr": 1, "es": 2, "de": 3}
id2label = {idx: lang for lang, idx in label2id.items()}
target_langs = set(label2id)

# Keep only rows whose language code (stored in the "labels" column) is one
# of the target languages.
filtered_dataset = dataset.filter(lambda example: example["labels"] in target_langs)
| |
|
# Checkpoint identifier shared by this tokenizer and the classification model
# loaded later in the script; both must come from the same checkpoint.
model_ckpt = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
| |
|
def preprocess(examples):
    """Tokenize a batch of examples and encode their language labels.

    Used with ``Dataset.map(..., batched=True)``, so ``examples`` maps each
    column name to a list of values. Every text is truncated or padded to
    exactly 64 tokens. Each string language code in ``examples["labels"]`` is
    replaced by its integer id from the module-level ``label2id``.

    Returns the tokenizer output with its ``"labels"`` entry set to the
    integer ids. Raises ``KeyError`` if a label is not a key of ``label2id``.
    """
    encoded = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=64,
    )
    encoded["labels"] = list(map(label2id.__getitem__, examples["labels"]))
    return encoded
| |
|
print("Preprocessing data...")

# Train on a reproducible (seed=42) random sample to keep training short.
# The sample size is clamped to the number of available rows, because
# select() rejects out-of-range indices when the filtered split is smaller
# than the requested sample.
train_size = min(1500, len(filtered_dataset))
train_subset = filtered_dataset.shuffle(seed=42).select(range(train_size))
tokenized_data = train_subset.map(preprocess, batched=True)
| |
|
# Sequence-classification model with one output per entry in label2id.
# Deriving num_labels from the label map keeps the head size consistent with
# id2label/label2id. Those mappings are passed to from_pretrained so the
# class-id <-> language mapping travels with the model (transformers stores
# them on the model config).
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
| |
|
# Training hyperparameters.
# save_strategy="no" skips intermediate checkpoints; the final model is
# saved explicitly once training finishes.
# use_cpu=True forces training on the CPU.
args = TrainingArguments(
    output_dir="my_real_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    save_strategy="no",
    use_cpu=True,
)

# NOTE(review): newer transformers releases deprecate Trainer(tokenizer=...)
# in favor of processing_class=...; switch once the pinned version supports it.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_data,
    tokenizer=tokenizer,
)
| |
|
# Fine-tune on the tokenized subset, then write the result to
# ./production_model.
print("Starting training...")
trainer.train()

print("Saving model to './production_model'...")
trainer.save_model(output_dir="production_model")
print("Done!")