innocentpeter committed on
Commit
00e5e26
·
verified ·
1 Parent(s): e719870

Update training/train_trenslation.py

Browse files
Files changed (1) hide show
  1. training/train_trenslation.py +55 -52
training/train_trenslation.py CHANGED
@@ -1,52 +1,55 @@
1
- import os
2
- import json
3
- from datasets import load_dataset
4
- from transformers import MarianTokenizer, MarianMTModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
5
-
6
- MODEL_NAME = "Helsinki-NLP/opus-mt-ha-en" # Hausa-English base model
7
- OUTPUT_DIR = "./training/outputs/model"
8
-
9
def train_from_jsonl(jsonl_path):
    """Fine-tune the Hausa->English Marian model on a JSONL corpus.

    Each line of *jsonl_path* must be a JSON object with ``"src"`` (Hausa)
    and ``"tgt"`` (English) string fields. 10% of the data is held out for
    per-epoch evaluation. The fine-tuned model and tokenizer are saved to
    OUTPUT_DIR.

    Args:
        jsonl_path: Path to the JSONL training file.
    """
    dataset = load_dataset("json", data_files={"train": jsonl_path}, split="train")

    # Hold out 10% as a validation split for the per-epoch evaluation below.
    dataset = dataset.train_test_split(test_size=0.1)

    tokenizer = MarianTokenizer.from_pretrained(MODEL_NAME)
    model = MarianMTModel.from_pretrained(MODEL_NAME)

    def preprocess(batch):
        # text_target= routes "tgt" through the target-language side of the
        # tokenizer — Marian models use separate source/target SentencePiece
        # vocabularies, so tokenizing targets as source text yields wrong ids.
        enc = tokenizer(
            batch["src"],
            text_target=batch["tgt"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        # Since the labels are already padded here, DataCollatorForSeq2Seq will
        # not pad them itself and therefore never masks them; replace pad ids
        # with -100 so the cross-entropy loss ignores padding positions.
        enc["labels"] = [
            [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
            for seq in enc["labels"]
        ]
        return enc

    tokenized = dataset.map(preprocess, batched=True, remove_columns=["src", "tgt"])
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
        save_total_limit=2,
        predict_with_generate=True,
        logging_dir="./training/logs",
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("✅ Training complete. Model saved at", OUTPUT_DIR)
 
 
 
 
1
+ # voice_translator/training/train_translation.py
2
+
3
+ import os
4
+ from datasets import load_dataset, Dataset
5
+ from transformers import (
6
+ MarianTokenizer,
7
+ MarianMTModel,
8
+ Seq2SeqTrainingArguments,
9
+ Seq2SeqTrainer,
10
+ DataCollatorForSeq2Seq,
11
+ )
12
+
13
+ MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
14
+ OUTPUT_DIR = "./training/outputs/model"
15
+
16
def train_from_jsonl(file_path):
    """Fine-tune the multilingual->English Marian model on a JSONL corpus.

    Each line of *file_path* must be a JSON object with ``"src"`` and
    ``"tgt"`` string fields. The fine-tuned model and tokenizer are saved
    to OUTPUT_DIR.

    Args:
        file_path: Path to the JSONL training file.

    Returns:
        A human-readable status string naming the save location.
    """
    # Load dataset
    dataset = load_dataset("json", data_files=file_path, split="train")

    tokenizer = MarianTokenizer.from_pretrained(MODEL_NAME)
    model = MarianMTModel.from_pretrained(MODEL_NAME)

    def preprocess(batch):
        # text_target= routes "tgt" through the target-language side of the
        # tokenizer — Marian models use separate source/target SentencePiece
        # vocabularies, so tokenizing targets as source text yields wrong ids.
        enc = tokenizer(
            batch["src"],
            text_target=batch["tgt"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        # Since the labels are already padded here, DataCollatorForSeq2Seq will
        # not pad them itself and therefore never masks them; replace pad ids
        # with -100 so the cross-entropy loss ignores padding positions.
        enc["labels"] = [
            [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
            for seq in enc["labels"]
        ]
        return enc

    # Drop the raw string columns so only model inputs reach the collator.
    tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)

    collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
    # recent transformers releases — confirm against the pinned version.
    args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="no",
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        num_train_epochs=3,
        save_total_limit=1,
        predict_with_generate=True,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        tokenizer=tokenizer,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    return f"✅ Model trained and saved to {OUTPUT_DIR}"