from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset

# Load tokenizer and model
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Load JSON datasets from local files
data_files = {
    "train": "dementia_train_split.json",
    "validation": "dementia_validation_split.json",
    "test": "dementia_test_multilang.json"
}
dataset = load_dataset("json", data_files=data_files)
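
# A record in these files is assumed (schema inferred from the
# remove_columns call below; the values here are made-up illustrations):
# {
#   "input": "I keep forgetting where I put things.",
#   "response": "That sounds frustrating. Let's think it through together.",
#   "emotion": "anxious", "intent": "reassure", "tags": ["memory"],
#   "care_mode": "companion", "language": "en",
#   "difficulty": "easy", "is_dementia_related": true
# }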

# Preprocessing function to tokenize inputs and outputs
def preprocess(example):
    prefix = "émotion: " if example.get("language", "en") == "fr" else "emotion: "
    input_enc = tokenizer(
        prefix + example["input"],
        padding="max_length",
        truncation=True,
        max_length=128
    )
    target_enc = tokenizer(
        example["response"],
        padding="max_length",
        truncation=True,
        max_length=128
    )
    input_enc["labels"] = target_enc["input_ids"]
    return input_enc
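
# Quick sanity check (a minimal sketch; index 0 is arbitrary)
sample_out = preprocess(dataset["train"][0])
print(len(sample_out["input_ids"]), len(sample_out["labels"]))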

# Tokenize and clean up metadata
tokenized_dataset = dataset.map(
    preprocess,
    remove_columns=["input", "response", "emotion", "intent", "tags", "care_mode", "language", "difficulty", "is_dementia_related"]
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./model",
    num_train_epochs=4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss"
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"]
)

# Start training
trainer.train()
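
# Optional sketch: score the held-out multilingual test split (loaded
# above but otherwise unused; assumes it shares the train/validation schema)
test_metrics = trainer.evaluate(tokenized_dataset["test"])
print(test_metrics)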

# Save the final model and tokenizer
trainer.save_model("./model")
tokenizer.save_pretrained("./model")

# Optional: push to the HF Hub (requires `huggingface-cli login` and
# push_to_hub=True in TrainingArguments; it defaults to False)
if training_args.push_to_hub:
    trainer.push_to_hub()
    tokenizer.push_to_hub("obx0x3/empathy-dementia")

print("✅ Model trained and saved!")