import torch
import torch.nn as nn
import shutil
import numpy as np
import pandas as pd
import mlflow
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    BitsAndBytesConfig,
    AutoModel,
)
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers.modeling_outputs import SequenceClassifierOutput

model_name = "MathGenie/MathCoder2-DeepSeekMath-7B"
MAX_LEN = 256

mlflow.set_tracking_uri("http://127.0.0.1:8081")

############################################################<-DATA->###########################################################

# Encode the Category and Misconception columns into integer labels.
le_category = LabelEncoder()
le_misconception = LabelEncoder()

train = pd.read_csv('category_misconception_folds.csv')
train.Misconception = train.Misconception.fillna('NA')
train['category_label'] = le_category.fit_transform(train['Category'])
train['misconception_label'] = le_misconception.fit_transform(train['Misconception'])
train.to_excel("train_text.xlsx")

n_category_classes = len(le_category.classes_)
n_misconception_classes = len(le_misconception.classes_)

print(f"Train shape : {train.shape}")
print(f"Category classes : {n_category_classes}")
print(f"Misconception classes: {n_misconception_classes}")
print(f"Category classes names : {le_category.classes_}")
print(train[['Category', 'category_label', 'Misconception', 'misconception_label']].head())

# Identify the correct answer for each question: rows whose Category starts with 'True'
# mark correct answers; keep the most frequent MC_Answer per QuestionId.
idx = train.apply(lambda row: row.Category.split('_')[0], axis=1) == 'True'
correct = train.loc[idx].copy()
correct['c'] = correct.groupby(['QuestionId', 'MC_Answer']).MC_Answer.transform('count')
correct = correct.sort_values('c', ascending=False)
correct = correct.drop_duplicates(['QuestionId'])
correct = correct[['QuestionId', 'MC_Answer']]
correct['is_correct'] = 1

train = train.merge(correct, on=['QuestionId', 'MC_Answer'], how='left')
train.is_correct = train.is_correct.fillna(0)


def format_input(row):
    x = "This answer is correct."
    if not row['is_correct']:
        x = "This answer is incorrect."
    return (
        f"Question: {row['QuestionText']}\n"
        f"Answer: {row['MC_Answer']}\n"
        f"{x}\n"
        f"Student Explanation: {row['StudentExplanation']}"
    )
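
# Sanity check (illustrative, safe to remove): print one formatted prompt so the
# template produced by format_input can be eyeballed before tokenization.
print("Example formatted input:\n" + format_input(train.iloc[0]))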

train['text'] = train.apply(format_input, axis=1)

# Fold 0 is used for training, fold 1 for validation.
train_df = train[train["fold"] == 0]
val_df = train[train["fold"] == 1]

COLS = ['text', 'category_label', 'misconception_label']
train_ds = Dataset.from_pandas(train_df[COLS])
val_ds = Dataset.from_pandas(val_df[COLS])

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def tokenize_func(examples):
    tokenized = tokenizer(
        examples["text"],
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LEN,
        padding=False,
    )
    tokenized['category_label'] = examples['category_label']
    tokenized['misconception_label'] = examples['misconception_label']
    return tokenized


train_ds = train_ds.map(tokenize_func, batched=True, desc="Tokenizing train data")
val_ds = val_ds.map(tokenize_func, batched=True, desc="Tokenizing validation data")

##########################################################<-END->###############################################################

############################################################<-MODEL->###########################################################


class MultiHeadClassificationModel(nn.Module):
    def __init__(self, model_name, n_category_classes, n_misconception_classes, **model_kwargs):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_name, **model_kwargs)
        self.base_model.config.use_cache = False  # Disable KV cache for training
        self.base_model.config.output_hidden_states = False
        self.base_model.config.output_attentions = False
        self.config = self.base_model.config

        hidden_size = self.base_model.config.hidden_size
        # Two linear heads share the same pooled representation.
        self.category_head = nn.Linear(hidden_size, n_category_classes)
        self.misconception_head = nn.Linear(hidden_size, n_misconception_classes)

        self.n_category_classes = n_category_classes
        self.n_misconception_classes = n_misconception_classes
        # Weights for combining the two cross-entropy losses.
        self.alpha = 0.6
        self.beta = 0.4

    def forward(self, input_ids, attention_mask=None, category_label=None,
                misconception_label=None, combined_label=None, **kwargs):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the last hidden state over the sequence dimension.
        pooled = outputs.last_hidden_state.mean(dim=1)

        category_logits = self.category_head(pooled)
        misconception_logits = self.misconception_head(pooled)

        loss = None
        if category_label is not None and misconception_label is not None:
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            category_loss_unreduced = loss_fct(category_logits, category_label)
            misconception_loss_unreduced = loss_fct(misconception_logits, misconception_label)

            # categories_with_subclasses = torch.tensor([1, 4], device=category_label.device)
            # mask = torch.isin(category_label, categories_with_subclasses).float()
            # misconception_loss_masked = misconception_loss_unreduced * mask

            category_loss = torch.mean(category_loss_unreduced)
            misconception_loss = torch.mean(misconception_loss_unreduced)
            loss = self.alpha * category_loss + self.beta * misconception_loss
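            # Worked example of the weighting: with alpha=0.6 and beta=0.4, a batch with
            # category_loss=1.0 and misconception_loss=2.0 gives
            # loss = 0.6 * 1.0 + 0.4 * 2.0 = 1.4.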
            # if mask.any():
            #     print(f"got the samples of misconception. so misco loss is : {misconception_loss} "
            #           f"and cat loss is : {category_loss} and final loss is : {loss}")

        return SequenceClassifierOutput(
            loss=loss,
            logits=(category_logits, misconception_logits)
        )


model_kwargs = dict(
    trust_remote_code=True,
    torch_dtype=torch.float16
)
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)

print(f"Loading model : {model_name}")
model = MultiHeadClassificationModel(
    model_name,
    n_category_classes=n_category_classes,
    n_misconception_classes=n_misconception_classes,
    **model_kwargs
)
model.base_model.config.pad_token_id = tokenizer.pad_token_id

lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules="all-linear",
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_CLS,
    modules_to_save=["category_head", "misconception_head"],
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print(f"Model Architecture : {model}")

##########################################################<-END->###############################################################

############################################################<-METRICS->#########################################################


def compute_multi_map(eval_pred, ks=[3, 5, 10]):
    """
    Computes MAP@k and a detailed rank distribution for both category and
    misconception predictions. This includes:
    - Rank counts for rank 1, ranks 2-3, and ranks above 3.
    - For the rank groups 2-3 and above 3, the top 3 most frequent classes and
      their average predicted probability.
    """
    # 1. Unpack logits and labels
    category_logits, misconception_logits = eval_pred.predictions
    category_labels, misconception_labels = eval_pred.label_ids
    category_labels = np.array(category_labels)
    misconception_labels = np.array(misconception_labels)

    # 2. Convert logits to probabilities.
    # Each `probs` array has shape (num_samples, num_classes).
    category_probs = torch.nn.functional.softmax(torch.tensor(category_logits), dim=-1).numpy()
    misconception_probs = torch.nn.functional.softmax(torch.tensor(misconception_logits), dim=-1).numpy()

    print(f"category_probs : {category_probs}")
    print(f"category_labels : {category_labels}")
    print(f"misconception_probs : {misconception_probs}")
    print(f"misconception_labels : {misconception_labels}")

    # 3. Get top-k predictions
    max_k = max(ks)
    category_top_k_preds = np.argsort(-category_probs, axis=1)[:, :max_k]
    misconception_top_k_preds = np.argsort(-misconception_probs, axis=1)[:, :max_k]

    # 4. Create boolean match arrays
    category_match_array = (category_top_k_preds == category_labels[:, None])
    misconception_match_array = (misconception_top_k_preds == misconception_labels[:, None])

    # 5. Compute MAP@k for each specified k
    metrics = {}

    # Category MAP@k
    for k in ks:
        match_at_k = category_match_array[:, :k]
        ranks = np.argmax(match_at_k, axis=1) + 1
        has_match_at_k = np.any(match_at_k, axis=1)
        scores = has_match_at_k * (1.0 / ranks)
        metrics[f"map@{k}_category"] = np.mean(scores)

    # Misconception MAP@k
    for k in ks:
        match_at_k = misconception_match_array[:, :k]
        ranks = np.argmax(match_at_k, axis=1) + 1
        has_match_at_k = np.any(match_at_k, axis=1)
        scores = has_match_at_k * (1.0 / ranks)
        metrics[f"map@{k}_misconception"] = np.mean(scores)
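    # Worked example (illustrative): with a single true class per sample, MAP@k
    # reduces to 1/rank. A sample whose true class sits at rank 2 of the top-k list
    # contributes 1/2, rank 1 contributes 1.0, and anything outside the top k
    # contributes 0.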
    # 6. Calculate detailed rank position breakdown for CATEGORY
    category_ranks_with_indices = [np.where(row)[0] for row in category_match_array]
    category_correct_ranks = np.array([r[0] + 1 if len(r) > 0 else max_k + 1 for r in category_ranks_with_indices])

    total = category_labels.shape[0]
    metrics["category_rank_1"] = np.sum(category_correct_ranks == 1)
    metrics["category_rank_2_to_3"] = np.sum((category_correct_ranks >= 2) & (category_correct_ranks <= 3))
    metrics["category_rank_above_3"] = np.sum((category_correct_ranks > 3) & (category_correct_ranks <= max_k))
    metrics["category_no_match_in_top_k"] = np.sum(category_correct_ranks > max_k)
    metrics["category_total"] = total

    # 7. Find top 3 classes for rank groups and their average probability - CATEGORY

    # --- For category ranks 2 to 3 ---
    category_rank_2_to_3_mask = (category_correct_ranks >= 2) & (category_correct_ranks <= 3)
    category_rank_2_to_3_labels = category_labels[category_rank_2_to_3_mask]
    if len(category_rank_2_to_3_labels) > 0:
        top_classes = Counter(category_rank_2_to_3_labels).most_common(3)
        augmented_top_classes = []
        for cls, count in top_classes:
            class_in_group_mask = (category_labels == cls) & category_rank_2_to_3_mask
            class_probs = category_probs[class_in_group_mask, cls]
            avg_prob = np.mean(class_probs)
            augmented_top_classes.append((cls, count, round(float(avg_prob), 4)))
        # metrics["category_rank_2_to_3_details"] = augmented_top_classes
    # else:
    #     metrics["category_rank_2_to_3_details"] = []

    # --- For category ranks above 3 (up to max_k) ---
    category_rank_above_3_mask = (category_correct_ranks > 3) & (category_correct_ranks <= max_k)
    category_rank_above_3_labels = category_labels[category_rank_above_3_mask]
    if len(category_rank_above_3_labels) > 0:
        top_classes = Counter(category_rank_above_3_labels).most_common(3)
        augmented_top_classes = []
        for cls, count in top_classes:
            class_in_group_mask = (category_labels == cls) & category_rank_above_3_mask
            class_probs = category_probs[class_in_group_mask, cls]
            avg_prob = np.mean(class_probs)
            augmented_top_classes.append((cls, count, round(float(avg_prob), 4)))
        # metrics["category_rank_above_3_details"] = augmented_top_classes
    # else:
    #     metrics["category_rank_above_3_details"] = []

    # 8. Calculate detailed rank position breakdown for MISCONCEPTION
    misconception_ranks_with_indices = [np.where(row)[0] for row in misconception_match_array]
    misconception_correct_ranks = np.array([r[0] + 1 if len(r) > 0 else max_k + 1 for r in misconception_ranks_with_indices])

    total = misconception_labels.shape[0]
    metrics["misconception_rank_1"] = np.sum(misconception_correct_ranks == 1)
    metrics["misconception_rank_2_to_3"] = np.sum((misconception_correct_ranks >= 2) & (misconception_correct_ranks <= 3))
    metrics["misconception_rank_above_3"] = np.sum((misconception_correct_ranks > 3) & (misconception_correct_ranks <= max_k))
    metrics["misconception_no_match_in_top_k"] = np.sum(misconception_correct_ranks > max_k)
    metrics["misconception_total"] = total
    # 9. Find top 3 classes for rank groups and their average probability - MISCONCEPTION

    # --- For misconception ranks 2 to 3 ---
    misconception_rank_2_to_3_mask = (misconception_correct_ranks >= 2) & (misconception_correct_ranks <= 3)
    misconception_rank_2_to_3_labels = misconception_labels[misconception_rank_2_to_3_mask]
    if len(misconception_rank_2_to_3_labels) > 0:
        top_classes = Counter(misconception_rank_2_to_3_labels).most_common(3)
        augmented_top_classes = []
        for cls, count in top_classes:
            class_in_group_mask = (misconception_labels == cls) & misconception_rank_2_to_3_mask
            class_probs = misconception_probs[class_in_group_mask, cls]
            avg_prob = np.mean(class_probs)
            augmented_top_classes.append((cls, count, round(float(avg_prob), 4)))
        # metrics["misconception_rank_2_to_3_details"] = augmented_top_classes
    # else:
    #     metrics["misconception_rank_2_to_3_details"] = []

    # --- For misconception ranks above 3 (up to max_k) ---
    misconception_rank_above_3_mask = (misconception_correct_ranks > 3) & (misconception_correct_ranks <= max_k)
    misconception_rank_above_3_labels = misconception_labels[misconception_rank_above_3_mask]
    if len(misconception_rank_above_3_labels) > 0:
        top_classes = Counter(misconception_rank_above_3_labels).most_common(3)
        augmented_top_classes = []
        for cls, count in top_classes:
            class_in_group_mask = (misconception_labels == cls) & misconception_rank_above_3_mask
            class_probs = misconception_probs[class_in_group_mask, cls]
            avg_prob = np.mean(class_probs)
            augmented_top_classes.append((cls, count, round(float(avg_prob), 4)))
        # metrics["misconception_rank_above_3_details"] = augmented_top_classes
    # else:
    #     metrics["misconception_rank_above_3_details"] = []

    # 10. Log metrics to MLflow for both category and misconception

    # Category metrics
    mlflow.log_metric("category_rank_1", metrics["category_rank_1"])
    mlflow.log_metric("category_rank_2_to_3", metrics["category_rank_2_to_3"])
    mlflow.log_metric("category_rank_above_3", metrics["category_rank_above_3"])
    mlflow.log_metric("category_no_match_in_top_k", metrics["category_no_match_in_top_k"])

    # Misconception metrics
    mlflow.log_metric("misconception_rank_1", metrics["misconception_rank_1"])
    mlflow.log_metric("misconception_rank_2_to_3", metrics["misconception_rank_2_to_3"])
    mlflow.log_metric("misconception_rank_above_3", metrics["misconception_rank_above_3"])
    mlflow.log_metric("misconception_no_match_in_top_k", metrics["misconception_no_match_in_top_k"])

    return metrics

##########################################################<-END->###############################################################

############################################################<-TRAINER->#########################################################

training_args = TrainingArguments(
    output_dir="MAP_EXP_18",
    eval_strategy="steps",
    save_strategy="no",
    logging_strategy="steps",
    logging_steps=100,
    eval_steps=500,
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    report_to="mlflow",
    group_by_length=True,
    max_grad_norm=1.0,
    weight_decay=0.01,
    num_train_epochs=2,
    label_names=['category_label', 'misconception_label'],
)

trainer = Trainer(
    model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_multi_map,
    data_collator=DataCollatorWithPadding(tokenizer),
)

##########################################################<-END->###############################################################
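
# Illustrative helper (an assumption, not part of the training loop): map the two heads'
# logits back to label strings via the fitted LabelEncoders, e.g. for offline inspection
# of eval predictions. Expects numpy arrays of shape (num_samples, num_classes).
def decode_predictions(category_logits, misconception_logits, k=3):
    cat_top = np.argsort(-category_logits, axis=1)[:, :k]        # top-k category class ids per sample
    mis_top = np.argsort(-misconception_logits, axis=1)[:, :k]   # top-k misconception class ids per sample
    cat_names = [le_category.inverse_transform(row) for row in cat_top]
    mis_names = [le_misconception.inverse_transform(row) for row in mis_top]
    return cat_names, mis_names
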
"__main__": trainer.train() trainer.save_model("MAP_EXP_18") source_file = "MAP_EXP_18.py" destination_directory = "MAP_EXP_18" shutil.copy(source_file, destination_directory) print(f"File '{source_file}' copied to '{destination_directory}'") print("Training completed and model saved!")