File size: 7,481 Bytes
9564ed2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq
import torch
from peft import LoraConfig, get_peft_model
import os
from tqdm import tqdm
import json
import random
from datasets import load_dataset
from datasets import Dataset, DatasetDict

# Persona text for a system turn. Not referenced by any active code in this
# file: create_conversation builds only user/assistant turns.
system_message = "You are a helpful assistant who is an expert in estimating quality of translations."

# Empty JSON skeleton of the evaluation the model is expected to produce.
# dataset_prep substitutes it into the prompt template as the ``template``
# field, so the model sees the schema it has to fill in.
# NOTE(review): presumably mirrors the annotation schema of the labeled data
# (after dataset_prep strips its extra fields) — confirm against the files.
output_template = '''
{
        "Accuracy Issues": [
                {
                        "Error Span": "",
                        "Error Explanation": "",
                        "Error Quality Category": "",
                        "Error Quality Tags": [],
                        "Error Severity": ""
                }
        ],
        "Accuracy Score": "",
        "Readability Issues": [
                {
                        "Error Span": "",
                        "Error Explanation": "",
                        "Error Quality Category": "",
                        "Error Quality Tags": [],
                        "Error Severity": ""
                }
        ],
        "Readability Score": ""
}'''

def create_conversation(input_sample, output_sample):
    """Wrap one prompt/response pair as a chat-format training example.

    Args:
        input_sample: the fully rendered user prompt.
        output_sample: the expected assistant reply (a JSON string in this
            script).

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``, where each turn is a
        ``{"role": ..., "content": ...}`` dict. No system turn is included.
    """
    user_turn = dict(role="user", content=input_sample)
    assistant_turn = dict(role="assistant", content=output_sample)
    return {"messages": [user_turn, assistant_turn]}

# Root directory holding the parsed, labeled annotation files (JSON).
data_path = (
    "/root/notebooks/MT_TQ/TQ/TQTune/labeled_data/parsed/"
)

# Maximum number of records drawn at random from each annotation file.
samples_per_file = 20

# Every ".json" file under data_path (searched recursively) whose name
# contains "PLDL". Sorted so the processing order does not depend on
# os.walk's filesystem-specific ordering.
json_files = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(data_path)
    for file in files
    if file.endswith(".json") and "PLDL" in file
)

training_samples = []
for json_file in tqdm(json_files):
    # Explicit UTF-8: records carry source/target text in various languages,
    # and open()'s default encoding depends on the platform locale.
    with open(json_file, "r", encoding="utf-8") as file:
        data = json.load(file)
    records = data["data"]
    # random.sample raises ValueError when k exceeds the population size, so
    # a file with fewer than samples_per_file records contributes all of them.
    training_samples.extend(random.sample(records, min(samples_per_file, len(records))))

# Raw-record key for each field of datapoint["input"] (in insertion order).
input_field_sources = {
    "src_text": "main_src_text",
    "tgt_text": "tgt_text",
    "src_prev": "tt_src_prev",
    "src_next": "tt_src_next",
    "tgt_prev": "tt_tgt_prev",
    "tgt_next": "tt_tgt_next",
    "src_lang": "src_lang",
    "tgt_lang": "tgt_lang",
}

# One datapoint per sampled record: the prompt inputs, plus the first
# labeler's annotation as the target evaluation.
datapoints = [
    {
        "input": {field: sample[key] for field, key in input_field_sources.items()},
        "evaluation": sample["labelers"][0]["annotation"],
    }
    for sample in training_samples
]

def dataset_prep(datapoints, test_size=0.2):
    """Shuffle, split and render datapoints as chat-format train/test sets.

    The user prompt is read from ``prompts.txt`` (relative to the current
    working directory). It is filled with ``str.format`` using the eight
    ``datapoint["input"]`` fields plus ``template=output_template``. The
    assistant reply is the labeler's evaluation, serialized as JSON after the
    "Confidence Level", "Main Vs Alternate" and "Score" fields are removed.

    Args:
        datapoints: list of dicts. Each has an ``"input"`` dict (keys
            ``src_text``, ``tgt_text``, ``src_prev``, ``src_next``,
            ``tgt_prev``, ``tgt_next``, ``src_lang``, ``tgt_lang``) and an
            ``"evaluation"`` dict. Neither the list nor its dicts are modified.
        test_size: fraction of the shuffled datapoints assigned to the test
            split, in ``[0, 1]``. The split happens before filtering, so the
            final train/test ratio can differ from it.

    Returns:
        ``(train_set, test_set)``: two lists of ``create_conversation``
        results.

    Raises:
        ValueError: if ``test_size`` is outside ``[0, 1]``.
        OSError: if ``prompts.txt`` cannot be read.
    """
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    # Annotation fields that are not part of the JSON the model learns to emit.
    dropped_keys = ("Confidence Level", "Main Vs Alternate", "Score")

    with open("prompts.txt", encoding="utf-8") as file:
        template_string = file.read()

    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(datapoints)
    random.shuffle(shuffled)

    split_index = int(len(shuffled) * (1 - test_size))
    train_datapoints = shuffled[:split_index]
    test_datapoints = shuffled[split_index:]

    def create_dataset(split_datapoints):
        """Render one split into conversations, skipping filtered-out samples."""
        dataset = []
        for datapoint in split_datapoints:
            inputs = datapoint["input"]
            # Build a stripped copy instead of deleting keys in place. The
            # evaluation dict is shared with the caller's data, and in-place
            # deletion made any second pass over it fail with KeyError.
            output = {
                key: value
                for key, value in datapoint["evaluation"].items()
                if key not in dropped_keys
            }

            # NOTE(review): only samples with at least one accuracy issue AND
            # at least one readability issue are kept. Error-free and
            # single-category samples never reach training — confirm intent.
            if not (output["Accuracy Issues"] and output["Readability Issues"]):
                continue

            prompt = template_string.format(
                src_text=inputs["src_text"],
                tgt_text=inputs["tgt_text"],
                src_prev=inputs["src_prev"],
                src_next=inputs["src_next"],
                tgt_prev=inputs["tgt_prev"],
                tgt_next=inputs["tgt_next"],
                src_lang=inputs["src_lang"],
                tgt_lang=inputs["tgt_lang"],
                template=output_template,
            )
            # ensure_ascii=False keeps non-ASCII text as real characters
            # instead of \uXXXX escapes. This matters for error spans, which
            # are presumably quoted from a non-Latin-script target. Escapes
            # take many more tokens in the training target.
            target = json.dumps(output, ensure_ascii=False)
            dataset.append(create_conversation(prompt, target))
        return dataset

    return create_dataset(train_datapoints), create_dataset(test_datapoints)

# Render the chat-format train/test splits and key them by split name.
train_dataset, test_dataset = dataset_prep(datapoints)
dataset = dict(train=train_dataset, test=test_dataset)

def convert_to_hf_dataset(dataset):
    """Turn ``{"train": [...], "test": [...]}`` row lists into a ``DatasetDict``.

    Each split's list of row dicts becomes a ``Dataset`` via
    ``Dataset.from_list``. Keys of ``dataset`` other than "train" and "test"
    are ignored.
    """
    splits = {name: Dataset.from_list(dataset[name]) for name in ("train", "test")}
    return DatasetDict(splits)

# Wrap the plain train/test lists in a Hugging Face DatasetDict.
hf_dataset = convert_to_hf_dataset(dataset)

# Print the DatasetDict summary as a quick sanity check before training.
print(hf_dataset)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# NOTE: not referenced below; device placement is handled by device_map="auto".
device = torch.device("cuda:0")

# Hugging Face model id
model_id = "google/gemma-3-12b-it"  # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# The 12B instruction model is loaded with its concrete Gemma 3 class; any
# other id falls back to the image-text-to-text auto class.
model_class = (
    Gemma3ForConditionalGeneration
    if model_id == "google/gemma-3-12b-it"
    else AutoModelForImageTextToText
)

torch_dtype = torch.bfloat16

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch_dtype,
    # Let accelerate place layers across all visible devices.
    device_map="auto",
)

# 8-bit (LLM.int8) weight quantization via bitsandbytes. load_in_8bit is the
# only option that applies here. BitsAndBytesConfig's double-quant,
# quant-type, compute-dtype and quant-storage settings exist only as
# bnb_4bit_* parameters (quant types "fp4"/"nf4"; there is no "nf8"). Unknown
# kwargs are merely warned about and dropped. For 4-bit QLoRA instead, use
# load_in_4bit=True with the bnb_4bit_* options.
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

model = model_class.from_pretrained(model_id, **model_kwargs)
# Load the instruction-tuned tokenizer so the official Gemma chat template is used.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-12b-it")

peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    # Train and save full copies of lm_head and embed_tokens next to the LoRA
    # adapters. NOTE(review): this is typically needed only when new special
    # tokens are added to the vocabulary, which this script does not do. It
    # adds substantial trainable parameters and optimizer state; consider
    # dropping it.
    modules_to_save=["lm_head", "embed_tokens"],
)

args = SFTConfig(
    output_dir="gemma-12b-tq-model",
    # NOTE(review): with packing=True, tokenized conversations are
    # concatenated and cut into max_seq_length-token chunks. A long prompt
    # plus its JSON answer can therefore be split across chunks; check
    # tokenized lengths against 512. Recent TRL versions name this parameter
    # `max_length` — confirm for the installed version.
    max_seq_length=512,
    packing=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=1,
    save_strategy="epoch",
    learning_rate=2e-4,
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    # NOTE(review): uploads checkpoints to the Hugging Face Hub. Make sure the
    # target repo is private if the labeled training data is not public.
    push_to_hub=True,
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
    },
    ddp_find_unused_parameters=False,
)

# NOTE(review): hf_dataset["test"] is not passed as eval_dataset, so the
# held-out split is never evaluated during training.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=hf_dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer,
)

trainer.train()
trainer.save_model()