# MAP_EXP_18 / MAP_EXP_18.py
# Uploaded by jaytonde05 ("Upload 8 files", commit 82d6a67, verified)
import torch
import torch.nn as nn
import shutil
import numpy as np
import pandas as pd
import mlflow
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from datasets import Dataset
from transformers import (
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
BitsAndBytesConfig,
AutoModel
)
from peft import (
LoraConfig,
TaskType,
get_peft_model,
prepare_model_for_kbit_training,
)
from transformers.modeling_outputs import SequenceClassifierOutput
# Backbone LM fine-tuned with two classification heads (see MODEL section).
model_name = "MathGenie/MathCoder2-DeepSeekMath-7B"
MAX_LEN = 256  # max tokenized length per example; longer inputs are truncated
mlflow.set_tracking_uri("http://127.0.0.1:8081")  # local MLflow server for metric logging
############################################################<-DATA->###########################################################
# Fit one label encoder per prediction target (category and misconception).
le_category = LabelEncoder()
le_misconception = LabelEncoder()
train = pd.read_csv('category_misconception_folds.csv')
# Missing misconceptions become the literal class 'NA' so the encoder can handle them.
train.Misconception = train.Misconception.fillna('NA')
train['category_label'] = le_category.fit_transform(train['Category'])
train['misconception_label'] = le_misconception.fit_transform(train['Misconception'])
train.to_excel("train_text.xlsx")  # snapshot of the encoded frame for manual inspection
n_category_classes = len(le_category.classes_)
n_misconception_classes = len(le_misconception.classes_)
print(f"Train shape : {train.shape}")
print(f"Category classes : {n_category_classes}")
print(f"Misconception classes: {n_misconception_classes}")
print(f"Category classes names : {le_category.classes_}")
print(train[['Category', 'category_label', 'Misconception', 'misconception_label']].head())
# Rows whose Category starts with 'True' (e.g. 'True_...') mark answers judged correct.
idx = train.apply(lambda row: row.Category.split('_')[0], axis=1) == 'True'
correct = train.loc[idx].copy()
# Per question, keep the most frequent "correct" MC_Answer as the canonical correct answer.
correct['c'] = correct.groupby(['QuestionId', 'MC_Answer']).MC_Answer.transform('count')
correct = correct.sort_values('c', ascending=False)
correct = correct.drop_duplicates(['QuestionId'])
correct = correct[['QuestionId', 'MC_Answer']]
correct['is_correct'] = 1
# Flag every row whose answer matches its question's canonical correct answer.
train = train.merge(correct, on=['QuestionId', 'MC_Answer'], how='left')
train.is_correct = train.is_correct.fillna(0)
def format_input(row):
    """Build the model input text for one training row.

    Combines the question, the chosen answer, a correctness statement and
    the student's explanation into a single multi-line prompt.

    Args:
        row: Mapping with 'QuestionText', 'MC_Answer', 'is_correct' (truthy
            when the answer matches the canonical correct answer) and
            'StudentExplanation'.

    Returns:
        The formatted prompt string.
    """
    # Fix: the incorrect-answer sentence previously read
    # "This is answer is incorrect." (typo: doubled "is answer is").
    if row['is_correct']:
        verdict = "This answer is correct."
    else:
        verdict = "This answer is incorrect."
    return (
        f"Question: {row['QuestionText']}\n"
        f"Answer: {row['MC_Answer']}\n"
        f"{verdict}\n"
        f"Student Explanation: {row['StudentExplanation']}"
    )
train['text'] = train.apply(format_input, axis=1)
# NOTE(review): only fold 0 (train) and fold 1 (validation) are used; rows in
# any other fold, if present, are silently dropped — confirm this is intended.
train_df = train[train["fold"]==0]
val_df = train[train["fold"]==1]
COLS = ['text', 'category_label', 'misconception_label']
train_ds = Dataset.from_pandas(train_df[COLS])
val_ds = Dataset.from_pandas(val_df[COLS])
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Causal-LM tokenizers typically ship without a pad token; reuse EOS and pad right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
def tokenize_func(examples):
    """Tokenize a batch of prompt texts and carry both label columns through.

    Padding is deferred to the data collator (dynamic per-batch padding);
    inputs longer than MAX_LEN are truncated.
    """
    encoded = tokenizer(
        examples["text"],
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LEN,
        padding=False,  # DataCollatorWithPadding pads each batch dynamically
    )
    for label_col in ("category_label", "misconception_label"):
        encoded[label_col] = examples[label_col]
    return encoded
train_ds = train_ds.map(tokenize_func, batched=True, desc="Tokenizing train data")
# Fix: the validation map reused the "train" progress description (copy-paste).
val_ds = val_ds.map(tokenize_func, batched=True, desc="Tokenizing validation data")
##########################################################<-END->###############################################################
############################################################<-MODEL->###########################################################
class MultiHeadClassificationModel(nn.Module):
    """Backbone LM with two linear heads: category and misconception.

    The training loss is a fixed weighted sum:
    ``alpha * CE(category) + beta * CE(misconception)``.
    """

    def __init__(self, model_name, n_category_classes, n_misconception_classes, **model_kwargs):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_name, **model_kwargs)
        self.base_model.config.use_cache = False  # KV cache is useless during training
        self.base_model.config.output_hidden_states = False
        self.base_model.config.output_attentions = False
        self.config = self.base_model.config  # exposed for Trainer/PEFT introspection
        hidden_size = self.base_model.config.hidden_size
        self.category_head = nn.Linear(hidden_size, n_category_classes)
        self.misconception_head = nn.Linear(hidden_size, n_misconception_classes)
        self.n_category_classes = n_category_classes
        self.n_misconception_classes = n_misconception_classes
        # Loss weights: category is weighted slightly above misconception.
        self.alpha = 0.6
        self.beta = 0.4

    def forward(self, input_ids, attention_mask=None, category_label=None, misconception_label=None, combined_label=None, **kwargs):
        """Run the backbone, pool, classify with both heads; return combined loss.

        Returns a SequenceClassifierOutput whose ``logits`` is the tuple
        ``(category_logits, misconception_logits)``.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # Fix: mean-pool only over real tokens. The original averaged every
        # position, so right-padding diluted the pooled representation and made
        # it depend on the batch's padded length.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = hidden.mean(dim=1)
        category_logits = self.category_head(pooled)
        misconception_logits = self.misconception_head(pooled)
        loss = None
        if category_label is not None and misconception_label is not None:
            # Default (mean) reduction is equivalent to the original
            # reduction='none' followed by torch.mean.
            loss_fct = nn.CrossEntropyLoss()
            category_loss = loss_fct(category_logits, category_label)
            misconception_loss = loss_fct(misconception_logits, misconception_label)
            loss = self.alpha * category_loss + self.beta * misconception_loss
        return SequenceClassifierOutput(
            loss=loss,
            logits=(category_logits, misconception_logits),
        )
model_kwargs = dict(
    trust_remote_code = True,
    torch_dtype = torch.float16
)
# 4-bit NF4 quantization (QLoRA-style) so the 7B backbone fits in limited VRAM.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_use_double_quant = True,  # quantize the quantization constants too
    bnb_4bit_compute_dtype = "float16",
)
print(f"Loading model : {model_name}")
model = MultiHeadClassificationModel(
    model_name,
    n_category_classes = n_category_classes,
    n_misconception_classes = n_misconception_classes,
    **model_kwargs
)
# Propagate the tokenizer's pad token so batched padding is handled correctly.
model.base_model.config.pad_token_id = tokenizer.pad_token_id
lora_config = LoraConfig(
    r = 64,
    lora_alpha = 64,  # scaling = alpha / r = 1.0
    target_modules = "all-linear",
    lora_dropout = 0.05,
    bias = "none",
    # NOTE(review): SEQ_CLS task type on a custom nn.Module wrapper (not an
    # AutoModelForSequenceClassification) — confirm PEFT wraps it as intended.
    task_type = TaskType.SEQ_CLS,
    # Train both classification heads fully alongside the LoRA adapters.
    modules_to_save = ["category_head", "misconception_head"],
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print(f"Model Architecture : {model}")
##########################################################<-END->###############################################################
############################################################<-METRICS->###########################################################
def compute_multi_map(eval_pred, ks=(3, 5, 10)):
    """Compute MAP@k and a first-correct-rank breakdown for both heads.

    For each head (``category``, ``misconception``) this produces:
      * ``map@{k}_{head}``           -- mean reciprocal rank truncated at k
      * ``{head}_rank_1``            -- samples whose true label ranked first
      * ``{head}_rank_2_to_3``       -- true label ranked 2nd or 3rd
      * ``{head}_rank_above_3``      -- true label ranked 4..max(ks)
      * ``{head}_no_match_in_top_k`` -- true label outside the top max(ks)
      * ``{head}_total``             -- number of evaluated samples
    The rank-breakdown counts are additionally logged to MLflow.

    Fixes vs. the original: mutable list default replaced with a tuple; dead
    per-class breakdown computations (results were only assigned in
    commented-out lines) removed; debug prints of full probability arrays
    dropped; the duplicated category/misconception code paths merged.
    """
    # eval_pred.label_ids is a tuple because TrainingArguments declares
    # label_names = ['category_label', 'misconception_label'].
    category_logits, misconception_logits = eval_pred.predictions
    category_labels = np.asarray(eval_pred.label_ids[0])
    misconception_labels = np.asarray(eval_pred.label_ids[1])

    max_k = max(ks)
    metrics = {}
    for head, logits, labels in (
        ("category", category_logits, category_labels),
        ("misconception", misconception_logits, misconception_labels),
    ):
        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).numpy()
        # Indices of the max_k highest-probability classes, best first.
        top_k_preds = np.argsort(-probs, axis=1)[:, :max_k]
        # match_array[i, j] is True when the j-th ranked class is the true label.
        match_array = top_k_preds == labels[:, None]

        for k in ks:
            match_at_k = match_array[:, :k]
            found = np.any(match_at_k, axis=1)
            # argmax on an all-False row yields rank 1, but `found` zeroes it out.
            ranks = np.argmax(match_at_k, axis=1) + 1
            metrics[f"map@{k}_{head}"] = float(np.mean(found * (1.0 / ranks)))

        # 1-based rank of the first correct prediction; max_k + 1 when absent.
        found_any = np.any(match_array, axis=1)
        first_rank = np.where(found_any, np.argmax(match_array, axis=1) + 1, max_k + 1)
        metrics[f"{head}_rank_1"] = int(np.sum(first_rank == 1))
        metrics[f"{head}_rank_2_to_3"] = int(np.sum((first_rank >= 2) & (first_rank <= 3)))
        metrics[f"{head}_rank_above_3"] = int(np.sum((first_rank > 3) & (first_rank <= max_k)))
        metrics[f"{head}_no_match_in_top_k"] = int(np.sum(first_rank > max_k))
        metrics[f"{head}_total"] = int(labels.shape[0])

    # Mirror the rank breakdown into MLflow (MAP@k already reaches MLflow via
    # the Trainer's report_to="mlflow" integration).
    for head in ("category", "misconception"):
        for suffix in ("rank_1", "rank_2_to_3", "rank_above_3", "no_match_in_top_k"):
            key = f"{head}_{suffix}"
            mlflow.log_metric(key, metrics[key])
    return metrics
##########################################################<-END->###############################################################
############################################################<-TRAINER->###########################################################
training_args = TrainingArguments(
    output_dir = "MAP_EXP_18",
    eval_strategy = "steps",
    save_strategy = "no",  # model is saved once, explicitly, after training
    logging_strategy = "steps",
    logging_steps = 100,
    eval_steps = 500,
    learning_rate = 1e-4,
    per_device_train_batch_size = 16,
    per_device_eval_batch_size = 32,
    lr_scheduler_type = "cosine",
    warmup_ratio = 0.05,
    report_to = "mlflow",
    group_by_length = True,  # bucket similar-length inputs to reduce padding
    max_grad_norm = 1.0,
    weight_decay = 0.01,
    num_train_epochs = 2,
    # Both label columns must be listed so the Trainer passes them to forward()
    # and hands them to compute_metrics as a tuple in label_ids.
    label_names = ['category_label', 'misconception_label']
)
trainer = Trainer(
    model,
    args = training_args,
    train_dataset = train_ds,
    eval_dataset = val_ds,
    tokenizer = tokenizer,
    compute_metrics = compute_multi_map,
    # Dynamically pads each batch to its longest sequence (padding=False above).
    data_collator = DataCollatorWithPadding(tokenizer)
)
##########################################################<-END->###############################################################
if __name__ == "__main__":
    trainer.train()
    trainer.save_model("MAP_EXP_18")
    # Archive this training script next to the saved model for reproducibility.
    # NOTE(review): shutil.copy raises if MAP_EXP_18.py is absent from the CWD.
    source_file = "MAP_EXP_18.py"
    destination_directory = "MAP_EXP_18"
    shutil.copy(source_file, destination_directory)
    print(f"File '{source_file}' copied to '{destination_directory}'")
    print("Training completed and model saved!")