| import ast |
| import json |
| import os |
| import sys |
| from datetime import datetime |
|
|
# Select the GPU before unsloth/torch are imported below; the environment
# variables are set first, exactly as in the original statement order.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

from unsloth import FastLanguageModel
import torch

# Base checkpoint identifier passed to the loader.
model_name = "unsloth/Qwen3-8B"

# Load the base model and tokenizer with 4-bit and 8-bit loading disabled
# and full fine-tuning turned off.
base_load_options = dict(
    model_name=model_name,
    max_seq_length=8192,
    load_in_4bit=False,
    load_in_8bit=False,
    full_finetuning=False,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_load_options)

# Modules that receive adapters: the attention projections plus the
# gate/up/down projections.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Wrap the model with PEFT adapters (rank 32, alpha 32, no dropout, no bias).
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=lora_target_modules,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
|
|
# Read the fine-tuning conversations from disk. The encoding is given
# explicitly so decoding does not depend on the platform locale; the JSON
# result files this script writes are UTF-8 as well.
with open(
    "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json",
    encoding="utf-8",
) as f:
    data = json.load(f)

from datasets import Dataset

# Build an in-memory dataset from the decoded JSON records.
dataset = Dataset.from_list(data)

from unsloth.chat_templates import standardize_sharegpt

# Pass the records through unsloth's ShareGPT standardizer before the chat
# template is applied.
dataset = standardize_sharegpt(dataset)
|
|
def formatting_prompts_func(examples):
    """Render every conversation in a batch into one training string.

    Args:
        examples: Batch mapping with a "conversations" entry that holds one
            message list per example.

    Returns:
        A dict with the single key "text" whose value lists the chat-template
        output for each conversation, in input order.
    """
    rendered = []
    for conversation in examples["conversations"]:
        # tokenize=False returns text rather than token ids, and no
        # generation prompt is appended after the final message.
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}
|
|
|
|
def parse_label_array(raw_text):
    """Extract a normalized list of support labels from free-form text.

    The text may precede the array with a ``<think>...</think>`` reasoning
    section, wrap it in a Markdown code fence, or surround it with extra
    prose. The array itself may be JSON or a Python list literal.

    Args:
        raw_text: Text expected to contain a list of label strings. ``None``
            and empty strings are accepted.

    Returns:
        A list in which every entry is either "supported" or
        "not_supported". Entries that are not strings, or that do not
        normalize to one of those two labels, become "not_supported". An
        empty list is returned when no list can be parsed.
    """
    text = (raw_text or "").strip()
    if not text:
        return []

    # Ignore reasoning output: keep only what follows the last closing tag,
    # then drop an unterminated reasoning section. Brackets inside the
    # reasoning would otherwise corrupt the "[" ... "]" slice below.
    text = text.rsplit("</think>", 1)[-1]
    text = text.split("<think>", 1)[0].strip()
    if not text:
        return []

    # Strip Markdown code-fence markers around the array.
    if "```" in text:
        text = text.replace("```json", "").replace("```", "").strip()

    # Narrow to the outermost bracketed span so surrounding prose is ignored.
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        text = text[start : end + 1]

    # Try strict JSON first, then fall back to a Python literal (which also
    # accepts single-quoted strings). Only the parsers' own failure modes
    # are caught.
    parsed = None
    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        break

    if not isinstance(parsed, list):
        return []

    valid_labels = {"supported", "not_supported"}
    normalized = []
    for item in parsed:
        if not isinstance(item, str):
            normalized.append("not_supported")
            continue
        # Fold case and treat "-" and " " as "_" so spelling variants match.
        label = item.strip().lower().replace("-", "_").replace(" ", "_")
        normalized.append(label if label in valid_labels else "not_supported")
    return normalized
|
|
|
|
def extract_conversation_pair(conversations):
    """Return the first user prompt and first assistant reply in a conversation.

    Both the role/content message layout and the ShareGPT from/value layout
    are understood, since the role is already looked up under either "role"
    or "from".

    Args:
        conversations: Iterable of message dicts, or ``None``. The speaker is
            read from "role" (falling back to "from") and the text from
            "content" (falling back to "value").

    Returns:
        A ``(user_prompt, gold_response)`` tuple. A part that is not found is
        returned as an empty string.
    """
    user_roles = {"user", "human"}
    assistant_roles = {"assistant", "gpt"}

    user_prompt = ""
    gold_response = ""
    for message in conversations or ():
        role = message.get("role") or message.get("from")
        content = message.get("content")
        if content is None:
            content = message.get("value")
        # Treat missing or null text as empty so a later message can fill in.
        content = content or ""

        if role in user_roles and not user_prompt:
            user_prompt = content
        elif role in assistant_roles and not gold_response:
            gold_response = content

        # Only the first non-empty message of each kind is kept, so there is
        # nothing left to find once both are set.
        if user_prompt and gold_response:
            break
    return user_prompt, gold_response
|
|
|
|
def generate_prediction(user_prompt, max_new_tokens=128):
    """Generate the model's reply to a single user prompt without sampling.

    Args:
        user_prompt: Text sent as the only user message of the chat.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 128.

    Returns:
        The decoded continuation, with the prompt tokens removed, special
        tokens skipped and surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        # do_sample=False already disables sampling; temperature is a
        # sampling-only option, so it is not passed here.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )

    # The output sequence starts with the prompt tokens; keep only what was
    # generated after them.
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
|
# Add the rendered "text" column to every record, in batches.
dataset = dataset.map(formatting_prompts_func, batched=True)

# Hold out 10% of the shuffled records for evaluation, with a fixed seed.
split_dataset = dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]
|
|
from trl import SFTTrainer, SFTConfig

# Settings for the supervised fine-tuning run, trained on the "text" column.
training_args = SFTConfig(
    dataset_text_field="text",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    warmup_steps=5,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_steps=1,
    per_device_eval_batch_size=8,
    bf16=True,
    tf32=True,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args,
)

# Run training and keep the value returned by train().
trainer_stats = trainer.train()
|
|
# Output directory named after the last path component of the model name.
merged_model_name = model_name.rsplit("/", 1)[-1]
save_dir = f"/home/mshahidul/readctrl_model/support_checking_vllm/{merged_model_name}"
os.makedirs(save_dir, exist_ok=True)

# Write the model with the "merged_16bit" save method, then the tokenizer
# files, into the same directory.
# NOTE(review): presumably this merges the adapters into 16-bit base
# weights - confirm against the unsloth documentation.
model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
tokenizer.save_pretrained(save_dir)
|
|
# Switch the model over for generation and put it in eval mode.
FastLanguageModel.for_inference(model)
model.eval()

# Directory that receives the prediction and accuracy reports.
model_info_dir = "/home/mshahidul/readctrl/code/support_check/model_info"
os.makedirs(model_info_dir, exist_ok=True)

# Filename components: a local-time stamp and the model's short name with
# dots replaced by underscores.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
model_tag = model_name.rsplit("/", 1)[-1].replace(".", "_")
|
|
# Running totals for the held-out evaluation.
results = []
exact_match_correct = 0
label_correct = 0
label_total = 0
parsed_prediction_count = 0

for idx, sample in enumerate(eval_dataset):
    user_prompt, gold_text = extract_conversation_pair(
        sample.get("conversations", [])
    )
    # Samples with no user prompt are skipped and never appear in results.
    if not user_prompt:
        continue

    gold_labels = parse_label_array(gold_text)
    pred_text = generate_prediction(user_prompt)
    pred_labels = parse_label_array(pred_text)

    # A prediction counts as parsed when it produced at least one label.
    parsed_prediction_count += 1 if pred_labels else 0

    # An exact match needs a non-empty gold list reproduced in full.
    exact_match = bool(gold_labels) and pred_labels == gold_labels
    exact_match_correct += 1 if exact_match else 0

    # Compare position by position; gold positions beyond the end of the
    # prediction simply score nothing.
    sample_label_correct = sum(
        1
        for gold_label, pred_label in zip(gold_labels, pred_labels)
        if gold_label == pred_label
    )
    label_correct += sample_label_correct
    label_total += len(gold_labels)

    if gold_labels:
        sample_accuracy = sample_label_correct / len(gold_labels)
    else:
        sample_accuracy = None

    results.append(
        {
            "sample_index": idx,
            "gold_labels": gold_labels,
            "predicted_labels": pred_labels,
            "raw_prediction": pred_text,
            "exact_match": exact_match,
            "label_accuracy": sample_accuracy,
        }
    )
|
|
# Aggregate metrics; both fall back to 0.0 when there is nothing to divide by.
total_samples = len(results)
exact_match_accuracy = 0.0 if not total_samples else exact_match_correct / total_samples
label_accuracy = 0.0 if not label_total else label_correct / label_total

# Run-level summary written next to the per-sample predictions.
accuracy_summary = dict(
    model_name=model_name,
    model_save_dir=save_dir,
    dataset_path="/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list.json",
    seed=3407,
    test_size=0.1,
    test_samples_evaluated=total_samples,
    parsed_prediction_count=parsed_prediction_count,
    exact_match_accuracy=exact_match_accuracy,
    label_accuracy=label_accuracy,
    exact_match_correct=exact_match_correct,
    label_correct=label_correct,
    label_total=label_total,
    timestamp=timestamp,
)

# Both report files share the model tag and timestamp in their names.
predictions_path = os.path.join(
    model_info_dir, f"{model_tag}_test_inference_{timestamp}.json"
)
accuracy_path = os.path.join(
    model_info_dir, f"{model_tag}_test_accuracy_{timestamp}.json"
)

# Predictions are written first, then the summary, both as indented UTF-8 JSON.
for out_path, payload in (
    (predictions_path, results),
    (accuracy_path, accuracy_summary),
):
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Exact match accuracy: {exact_match_accuracy:.4f}")
print(f"Label accuracy: {label_accuracy:.4f}")
|
|
| |
| |
| |
|
|
|
|