from transformers import AutoTokenizer, AutoModelForImageTextToText, BitsAndBytesConfig, Gemma3ForConditionalGeneration
import torch
from peft import LoraConfig
import os
from tqdm import tqdm
import json
import random
from datasets import Dataset, DatasetDict

system_message = "You are a helpful assistant who is an expert in estimating quality of translations."

# Empty JSON skeleton the model is asked to fill in for each segment.
output_template = '''
{
    "Accuracy Issues": [
        {
            "Error Span": "",
            "Error Explanation": "",
            "Error Quality Category": "",
            "Error Quality Tags": [],
            "Error Severity": ""
        }
    ],
    "Accuracy Score": "",
    "Readability Issues": [
        {
            "Error Span": "",
            "Error Explanation": "",
            "Error Quality Category": "",
            "Error Quality Tags": [],
            "Error Severity": ""
        }
    ],
    "Readability Score": ""
}'''


def create_conversation(input_sample, output_sample):
    return {
        "messages": [
            # {"role": "system", "content": system_message},
            {"role": "user", "content": input_sample},
            {"role": "assistant", "content": output_sample},
        ]
    }


data_path = "/root/notebooks/MT_TQ/TQ/TQTune/labeled_data/parsed/"

json_files = [
    os.path.join(root, file)
    for root, _, files in os.walk(data_path)
    for file in files
    if file.endswith(".json") and "PLDL" in file
]

training_samples = []
for json_file in tqdm(json_files):
    with open(json_file, "r") as file:
        data = json.load(file)
    # Guard against files with fewer than 20 records, which would make
    # random.sample raise a ValueError.
    sampled_items = random.sample(data["data"], min(20, len(data["data"])))
    training_samples.extend(sampled_items)

datapoints = []
for sample in training_samples:
    datapoint = {"input": {}}
    datapoint["input"]["src_text"] = sample["main_src_text"]
    datapoint["input"]["tgt_text"] = sample["tgt_text"]
    datapoint["input"]["src_prev"] = sample["tt_src_prev"]
    datapoint["input"]["src_next"] = sample["tt_src_next"]
    datapoint["input"]["tgt_prev"] = sample["tt_tgt_prev"]
    datapoint["input"]["tgt_next"] = sample["tt_tgt_next"]
    datapoint["input"]["src_lang"] = sample["src_lang"]
    datapoint["input"]["tgt_lang"] = sample["tgt_lang"]
    datapoint["evaluation"] = sample["labelers"][0]["annotation"]
    datapoints.append(datapoint)


def dataset_prep(datapoints, test_size=0.2):
    with open("prompts.txt") as file:
        template_string = file.read()

    random.shuffle(datapoints)
    split_index = int(len(datapoints) * (1 - test_size))
    train_datapoints = datapoints[:split_index]
    test_datapoints = datapoints[split_index:]

    def create_dataset(datapoints):
        dataset = []
        for datapoint in datapoints:
            src_text = datapoint['input']['src_text']
            tgt_text = datapoint['input']['tgt_text']
            src_prev = datapoint['input']['src_prev']
            src_next = datapoint['input']['src_next']
            tgt_prev = datapoint['input']['tgt_prev']
            tgt_next = datapoint['input']['tgt_next']
            src_lang = datapoint['input']['src_lang']
            tgt_lang = datapoint['input']['tgt_lang']

            output = datapoint['evaluation']
            # Drop annotation fields the model should not learn to emit;
            # pop() avoids a KeyError when a field is missing.
            output.pop("Confidence Level", None)
            output.pop("Main Vs Alternate", None)
            output.pop("Score", None)

            # Keep only segments annotated with both accuracy and readability issues.
            if len(output['Accuracy Issues']) != 0 and len(output['Readability Issues']) != 0:
                item = template_string.format(
                    src_text=src_text,
                    tgt_text=tgt_text,
                    src_prev=src_prev,
                    src_next=src_next,
                    tgt_prev=tgt_prev,
                    tgt_next=tgt_next,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    template=output_template,
                )
                dataset.append(create_conversation(item, json.dumps(output)))
        return dataset

    train_set = create_dataset(train_datapoints)
    test_set = create_dataset(test_datapoints)
    return train_set, test_set


train_dataset, test_dataset = dataset_prep(datapoints)
dataset = {"train": train_dataset, "test": test_dataset}
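
# --- Optional sanity check (a suggested addition, not part of the original
# pipeline): print the split sizes and one rendered conversation so you can
# verify that prompts.txt formats cleanly before committing to a training
# run. The 2000-character truncation is only for readability. ---
print(f"train: {len(dataset['train'])} examples, test: {len(dataset['test'])} examples")
if dataset["train"]:
    print(json.dumps(dataset["train"][0], indent=2, ensure_ascii=False)[:2000])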
def convert_to_hf_dataset(dataset):
    # Convert the train and test lists into Hugging Face Dataset objects
    train_dataset = Dataset.from_list(dataset['train'])
    test_dataset = Dataset.from_list(dataset['test'])
    # Combine them into a DatasetDict
    hf_dataset = DatasetDict({
        'train': train_dataset,
        'test': test_dataset,
    })
    return hf_dataset


# Convert the raw dict into a Hugging Face DatasetDict
hf_dataset = convert_to_hf_dataset(dataset)
print(hf_dataset)

# Hugging Face model id
model_id = "google/gemma-3-12b-it"  # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select the model class based on the id: the 12b-it checkpoint loads as
# Gemma3ForConditionalGeneration; fall back to the generic image-text class otherwise.
if model_id == "google/gemma-3-12b-it":
    model_class = Gemma3ForConditionalGeneration
else:
    model_class = AutoModelForImageTextToText

torch_dtype = torch.bfloat16

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch_dtype,
    device_map="auto",  # let accelerate place/shard the model across available GPUs
)

# 8-bit quantization via bitsandbytes. Note that double quantization, quant
# type, compute dtype, and quant storage are 4-bit-only options (the
# `bnb_4bit_*` family) and there is no "nf8" quant type, so
# `load_in_8bit=True` is the entire 8-bit configuration.
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)  # instruction tokenizer with the official Gemma chat template

peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],  # save lm_head and embed_tokens when training the special tokens
)

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-12b-tq-model",
    max_seq_length=512,
    packing=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=1,
    save_strategy="epoch",
    learning_rate=2e-4,
    fp16=(torch_dtype == torch.float16),
    bf16=(torch_dtype == torch.bfloat16),
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=True,  # requires an authenticated Hugging Face login
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,  # the chat template already adds the special tokens
        "append_concat_token": True,  # add an EOS token as separator between packed examples
    },
    ddp_find_unused_parameters=False,
    no_cuda=False,
)

from trl import SFTTrainer

# Create the Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=hf_dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer,
)

trainer.train()
trainer.save_model()
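
# --- Hedged usage sketch (not part of the original script): load the saved
# LoRA adapter back onto the base model and score one held-out example. The
# adapter path "gemma-12b-tq-model" matches `output_dir` above; the choice of
# the first test sample and `max_new_tokens=512` are illustrative assumptions.
# Run this in a fresh session (or free the quantized training model first) to
# avoid holding two copies of the 12B weights in memory. ---
from peft import PeftModel

base_model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
tuned_model = PeftModel.from_pretrained(base_model, "gemma-12b-tq-model")
tuned_model.eval()

# Re-use only the user turn of a held-out conversation and let the tuned
# model produce the JSON evaluation.
sample_messages = hf_dataset["test"][0]["messages"][:1]
input_ids = tokenizer.apply_chat_template(
    sample_messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(tuned_model.device)

with torch.inference_mode():
    generated = tuned_model.generate(input_ids, max_new_tokens=512)

# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(generated[0][input_ids.shape[-1]:], skip_special_tokens=True))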