| | import ast |
| | import json |
| | import os |
| | from datetime import datetime |
| |
|
| | import torch |
| | from datasets import Dataset |
| | from trl import SFTConfig, SFTTrainer |
| | from unsloth import FastLanguageModel |
| |
|
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Base checkpoint that gets LoRA fine-tuned.
model_name = "unsloth/Llama-3.2-3B-Instruct"
# JSON file containing a list of records; each record carries a
# "conversations" message list (see the data-loading and evaluation code).
data_path = "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json"
# Held-out fraction for post-training evaluation, and the seed reused by the
# split, the LoRA initialisation and the trainer.
test_size, seed = 0.1, 3407
# Sequence budget handed to both the model loader and the trainer.
max_seq_length = 2048
# Whether the base weights are loaded 4-bit quantised.
load_in_4bit = True
| |
|
| |
|
def formatting_prompts_func(examples):
    """Render a batch of chat conversations into flat training strings.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["conversations"]``
    holds one message list per row. Each list is rendered through the module-level
    ``tokenizer``'s chat template as text, without a trailing generation prompt.
    A leading ``<|begin_of_text|>`` marker is removed — presumably so the BOS
    token is not duplicated when the trainer tokenizes the text later (TODO
    confirm against the trainer's tokenization settings).

    Returns:
        dict: ``{"text": [...]}`` with one rendered string per conversation.
    """
    bos_marker = "<|begin_of_text|>"
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix(bos_marker))
    return {"text": rendered}
| |
|
| |
|
def parse_label_array(raw_text):
    """Extract a normalised list of support labels from free-form text.

    Markdown code fences are dropped, then the span from the first "[" to the
    last "]" (when both exist in that order) is decoded — first as JSON, then as
    a Python literal. Anything that does not decode to a list yields ``[]``.

    Every element is normalised: stripped, lower-cased, with "-" and spaces
    turned into "_". Values that are not exactly "supported" / "not_supported"
    afterwards — including non-string elements — become "not_supported".

    Args:
        raw_text: Model output or gold response; ``None`` or blank gives ``[]``.

    Returns:
        list[str]: labels, each "supported" or "not_supported".
    """
    cleaned = (raw_text or "").strip()
    if not cleaned:
        return []

    if "```" in cleaned:
        # Remove markdown fences such as ```json ... ``` around the payload.
        cleaned = cleaned.replace("```json", "").replace("```", "").strip()

    open_idx = cleaned.find("[")
    close_idx = cleaned.rfind("]")
    if 0 <= open_idx < close_idx:
        # Keep only the outermost bracketed span, discarding surrounding prose.
        cleaned = cleaned[open_idx : close_idx + 1]

    candidate = None
    for decode in (json.loads, ast.literal_eval):
        try:
            candidate = decode(cleaned)
        except Exception:
            # Best-effort parsing: fall through to the next decoder.
            continue
        break

    if not isinstance(candidate, list):
        return []

    allowed = {"supported", "not_supported"}

    def _normalise(entry):
        if not isinstance(entry, str):
            return "not_supported"
        token = entry.strip().lower().replace("-", "_").replace(" ", "_")
        return token if token in allowed else "not_supported"

    return [_normalise(entry) for entry in candidate]
| |
|
| |
|
def extract_conversation_pair(conversations):
    """Return the first user prompt and the first assistant reply of a conversation.

    Two message schemas are accepted:
      * role/content style: ``{"role": "user" | "assistant", "content": ...}``
      * ShareGPT style: ``{"from": "human" | "gpt", "value": ...}``
    ``"human"`` / ``"gpt"`` speakers are mapped to ``"user"`` / ``"assistant"``.
    Messages with any other role (e.g. ``"system"``) and non-dict entries are
    skipped. A ``content`` of ``None`` falls back to ``"value"``, then to ``""``.

    Args:
        conversations: Iterable of message dicts; ``None`` is treated as empty.

    Returns:
        tuple[str, str]: ``(user_prompt, gold_response)``; an element is ``""``
        when no matching (non-empty) message was found.
    """
    # ShareGPT speaker names -> chat-template role names.
    role_aliases = {"human": "user", "gpt": "assistant"}
    user_prompt = ""
    gold_response = ""
    for message in conversations or []:
        if not isinstance(message, dict):
            continue
        role = message.get("role") or message.get("from")
        role = role_aliases.get(role, role)

        content = message.get("content")
        if content is None:
            content = message.get("value")
        if content is None:
            content = ""

        # Only the first non-empty message of each role is kept.
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content

        if user_prompt and gold_response:
            break  # later messages can no longer change the result
    return user_prompt, gold_response
| |
|
| |
|
def generate_prediction(user_prompt, max_new_tokens=128):
    """Greedily generate the model's reply to a single user message.

    The prompt is rendered with the module-level ``tokenizer``'s chat template
    (assistant generation header appended) and then encoded with
    ``add_special_tokens=False``. The rendered template already starts with
    ``<|begin_of_text|>`` (the training formatter strips exactly that prefix),
    so letting the tokenizer add its BOS again would feed the model a doubled
    BOS that it never saw during training.

    Args:
        user_prompt: Content of the single user turn.
        max_new_tokens: Generation budget. Outputs cut off by this budget
            usually lose their closing "]" and then fail to parse, so raise it
            for prompts expecting many labels. Defaults to 128.

    Returns:
        str: Decoded completion with the prompt tokens removed and special
        tokens skipped, stripped of surrounding whitespace.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The template text already carries the BOS marker; don't add a second one.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; temperature is unused without sampling, so it is
            # not passed (passing 0.0 only triggers a generation-config warning).
            do_sample=False,
            use_cache=True,
        )
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
| |
|
| |
|
| | |
# Load the base checkpoint (and its tokenizer) through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    dtype=None,
    load_in_4bit=load_in_4bit,
    max_seq_length=max_seq_length,
)

# Projection modules that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Wrap the model with LoRA adapters: rank 16, alpha 16, no dropout,
# bias="none", Unsloth gradient checkpointing; random_state reuses the run seed.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=lora_target_modules,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    random_state=seed,
    use_gradient_checkpointing="unsloth",
)
| |
|
| | |
# Read the raw records and split them into train / held-out test partitions.
with open(data_path, encoding="utf-8") as data_file:
    raw_data = json.load(data_file)

raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(
    test_size=test_size,
    seed=seed,
    shuffle=True,
)
train_raw, test_raw = split_dataset["train"], split_dataset["test"]

# Only the training split is rendered to text; the test split stays raw so the
# evaluation loop can pull prompts and gold answers out of "conversations".
train_dataset = train_raw.map(formatting_prompts_func, batched=True)
| |
|
| | |
# Tag this run and create its output folders.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Last path component of the hub id, with "." made path/identifier friendly.
model_tag = model_name.rsplit("/", 1)[-1].replace(".", "_")

# NOTE: the merged-model folder is not timestamped, so each run overwrites the
# previous export for the same base model; run metadata is kept per run.
model_save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/{model_tag}"
run_info_dir = os.path.join(
    "/home/mshahidul/readctrl/code/support_check/model_info",
    f"{model_tag}_{timestamp}",
)
for output_dir in (model_save_dir, run_info_dir):
    os.makedirs(output_dir, exist_ok=True)
| |
|
| | |
# Mixed precision: bf16 when the GPU supports it, fp16 otherwise.
use_bf16 = torch.cuda.is_bf16_supported()

# NOTE(review): effective batch per device is 2 x 4 = 8, so max_steps=30 covers
# roughly 240 examples regardless of the training-split size — confirm this
# is intended and not a leftover smoke-test setting.
training_args = SFTConfig(
    output_dir=os.path.join(run_info_dir, "trainer_outputs"),
    seed=seed,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=30,
    warmup_steps=5,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    optim="adamw_8bit",
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=1,
    report_to="none",
)

# NOTE(review): tokenizer, dataset_text_field and max_seq_length are given to
# SFTTrainer directly. Recent TRL releases expect the latter two on SFTConfig
# (and `processing_class` instead of `tokenizer`); this relies on the installed
# TRL/Unsloth combination still accepting these keywords.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
)

# Fine-tune, then export merged 16-bit weights together with the tokenizer.
trainer.train()
model.save_pretrained_merged(model_save_dir, tokenizer, save_method="merged_16bit")
tokenizer.save_pretrained(model_save_dir)
| |
|
| | |
# Switch the fine-tuned model to inference mode for held-out evaluation.
FastLanguageModel.for_inference(model)
model.eval()

results = []
exact_match_correct = 0
label_correct = 0
label_total = 0
parsed_prediction_count = 0

# NOTE(review): the evaluation prompt contains only the first user message;
# training rendered the whole conversation. If the data contains system
# messages, train and inference prompts differ — confirm this is acceptable.
for idx, sample in enumerate(test_raw):
    user_prompt, gold_text = extract_conversation_pair(sample.get("conversations", []))
    if not user_prompt:
        continue  # nothing to prompt the model with; sample is not counted

    gold_labels = parse_label_array(gold_text)
    pred_text = generate_prediction(user_prompt)
    pred_labels = parse_label_array(pred_text)

    # A prediction counts as "parsed" only when it yields a non-empty list.
    parsed_prediction_count += bool(pred_labels)

    # Unparseable gold (empty list) can never be an exact match.
    exact_match = bool(gold_labels) and pred_labels == gold_labels
    exact_match_correct += exact_match

    # Position-wise agreement over the gold labels; missing predictions count as wrong.
    sample_label_correct = sum(gold == pred for gold, pred in zip(gold_labels, pred_labels))
    label_correct += sample_label_correct
    label_total += len(gold_labels)

    per_sample_accuracy = sample_label_correct / len(gold_labels) if gold_labels else None
    results.append(
        {
            "sample_index": idx,
            "gold_labels": gold_labels,
            "predicted_labels": pred_labels,
            "raw_prediction": pred_text,
            "exact_match": exact_match,
            "label_accuracy": per_sample_accuracy,
        }
    )
| |
|
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


total_samples = len(results)
# Exact match is averaged over every evaluated sample; label accuracy is
# averaged over individual gold labels.
exact_match_accuracy = _safe_ratio(exact_match_correct, total_samples)
label_accuracy = _safe_ratio(label_correct, label_total)

accuracy_summary = dict(
    model_name=model_name,
    model_save_dir=model_save_dir,
    run_info_dir=run_info_dir,
    dataset_path=data_path,
    seed=seed,
    test_size=test_size,
    test_samples_evaluated=total_samples,
    parsed_prediction_count=parsed_prediction_count,
    exact_match_accuracy=exact_match_accuracy,
    label_accuracy=label_accuracy,
    exact_match_correct=exact_match_correct,
    label_correct=label_correct,
    label_total=label_total,
    timestamp=timestamp,
)

predictions_path = os.path.join(run_info_dir, "test_inference.json")
accuracy_path = os.path.join(run_info_dir, "test_accuracy.json")

# Per-sample predictions first, then the aggregate metrics.
for out_path, payload in ((predictions_path, results), (accuracy_path, accuracy_summary)):
    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(payload, out_file, ensure_ascii=False, indent=2)

for report_line in (
    f"Saved merged model to: {model_save_dir}",
    f"Saved run info folder to: {run_info_dir}",
    f"Saved test inference to: {predictions_path}",
    f"Saved test accuracy to: {accuracy_path}",
    f"Exact match accuracy: {exact_match_accuracy:.4f}",
    f"Label accuracy: {label_accuracy:.4f}",
):
    print(report_line)