| """ |
| Base module for Knowledge Tracing LLM inference. |
| |
| This module contains all shared logic for running KT inference with different models. |
| Each model script imports this and provides model-specific configuration. |
| |
| Usage in model scripts: |
| from kt_inference_base import run_inference |
| |
| MODEL_CONFIG = { |
| "model_id": "model/name", |
| "gen_configs": {...}, |
| "output_prefix": "prefix", |
| "system_prompt_prefix": "", # e.g., "Reasoning: medium\n\n" |
| } |
| |
| if __name__ == "__main__": |
| run_inference(MODEL_CONFIG) |
| """ |
|
|
import argparse
import contextlib
import gc
import json
import os
import re
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (
    destroy_model_parallel,
    destroy_distributed_environment,
)

from clean_utils import clean_problem_body
from cleantext import clean_text as clean_text_legacy


class NumpyEncoder(json.JSONEncoder):
    """Custom JSON encoder that handles numpy types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
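# Illustrative usage (values invented for the example): NumpyEncoder lets json.dumps
# serialize numpy scalars and arrays that would otherwise raise a TypeError, e.g.
#   json.dumps({"score": np.int64(1), "probs": np.array([0.25, 0.75])}, cls=NumpyEncoder)
#   -> '{"score": 1, "probs": [0.25, 0.75]}'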
# Default run parameters
DEFAULT_BATCH_SIZE = 10000
DEFAULT_NUM_STUDENTS = 500
DEFAULT_BIN_SIZE = 50
DEFAULT_MIN_HISTORY = 50

# Input file names (expected inside --data-dir)
STUDENT_FILE = "Interactions.csv"
PROBLEMS_FILE = "Problems.csv"
SKILL_FILE = "Skills.csv"


BASE_SYSTEM_PROMPT = """You are a reasoning model trained to simulate a student's evolving knowledge and response behavior in mathematics.

Your goal is to infer, from past problem–answer pairs, how this same student is likely to perform on a new problem — at multiple levels of granularity.

You must reason about the student's learning progression, skill mastery, and recurring misconceptions, then produce structured predictions for the new item.

---

Your Task:

Generate three coordinated predictions for this student:

1) **Skill-level knowledge tracing (0 or 1):** Whether the student has mastered the underlying skill involved in the new problem.
2) **Question-level knowledge tracing (0 or 1):** Whether the student will answer this specific problem correctly.
3) **Cognitive-level prediction (string):** The exact answer text or option the student would most likely produce, written in their own response style.

---

Reasoning Guidelines:

- Use the student's historical data (problems, answers, hints, timestamps) to infer learning and forgetting patterns.
- Consider recency and exposure: later timestamps often indicate updated knowledge.
- Treat `UsedHint=True` or `SawAnswer=True` as evidence that the student's recorded answer may not reflect true mastery — they might have seen or been helped toward the solution.
- Attend to how the student's accuracy, style, and misconceptions evolve over time.
- You may think step-by-step internally, but your final output must follow the format below.

---

Output Format:

When you are done reasoning, **finish your response with** a JSON object in this exact structure:

For Multiple Choice (select 1) problems:
{
  "skill_level": 0 or 1,
  "question_level": 0 or 1,
  "student_answer": "A" (single letter only)
}

For Multiple Choice (select all) problems:
{
  "skill_level": 0 or 1,
  "question_level": 0 or 1,
  "student_answer": "A, C" (comma-separated letters if multiple selections)
}

For Fill-in problems:
{
  "skill_level": 0 or 1,
  "question_level": 0 or 1,
  "student_answer": "<string exactly as this student would write (e.g., 'x=3', '3/5', '12')>"
}

Predictions must be consistent. If you predict question_level to be 1, then student_answer must match the correct answer. If you predict question_level to be 0, student_answer must not match the correct answer."""


def parse_args(default_output_jsonl):
    """Parse command line arguments.

    Note: default_output_jsonl is currently unused here; when --output is not
    given, the output path is auto-generated in run_inference().
    """
    parser = argparse.ArgumentParser(description="Knowledge Tracing with LLM")
    parser.add_argument(
        "--batch-size", "-b",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help=f"Batch size for LLM inference (default: {DEFAULT_BATCH_SIZE})"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default=None,
        help="Output JSONL file path (overrides the auto-generated name)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=".",
        help="Output directory for results (default: current directory)"
    )
    parser.add_argument(
        "--data-dir", "-d",
        type=str,
        default=".",
        help="Directory containing input CSV files (default: current directory)"
    )
    parser.add_argument(
        "--cache-dir", "-c",
        type=str,
        default=None,
        help="Directory for the vLLM model cache (default: vLLM's default)"
    )
    parser.add_argument(
        "--num-students", "-n",
        type=int,
        default=DEFAULT_NUM_STUDENTS,
        help=f"Number of students to sample (default: {DEFAULT_NUM_STUDENTS}; use 0 or -1 for all students)"
    )
    parser.add_argument(
        "--bin-size",
        type=int,
        default=DEFAULT_BIN_SIZE,
        help=f"Size of each prediction bin (default: {DEFAULT_BIN_SIZE})"
    )
    parser.add_argument(
        "--min-history",
        type=int,
        default=DEFAULT_MIN_HISTORY,
        help=f"Minimum history size before making predictions (default: {DEFAULT_MIN_HISTORY})"
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        default=1,
        help="Number of GPUs for tensor parallelism (default: 1)"
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=None,
        help="Maximum number of sequences to process in a batch (vLLM; default: 256)"
    )
    parser.add_argument(
        "--reasoning-level",
        type=str,
        choices=["none", "low", "medium", "high"],
        default=None,
        help="Reasoning level for GPT-OSS models only. Default: use the model config (medium for GPT-OSS, none for Qwen)"
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=None,
        help="Maximum sequence length in tokens (vLLM; default: the model's context length)"
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="Fraction of GPU memory to use (vLLM; default: 0.9, range: 0.0-1.0)"
    )
    parser.add_argument(
        "--legacy-clean",
        action="store_true",
        default=False,
        help="Use the legacy text cleaner (cleantext.py) instead of clean_utils.py"
    )
    return parser.parse_args()
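# Example invocation of a model script built on this module (the script name and
# paths are illustrative placeholders):
#   python kt_infer_model.py --data-dir ./data --output-dir ./results \
#       --num-students 200 --bin-size 50 --min-history 50 --num-gpus 2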
def label_answer_options(answer_string):
    """
    Convert pipe-delimited answers to lettered format.
    Input: "Han is correct || Elena is correct || Both are correct"
    Output: {"A": "Han is correct", "B": "Elena is correct", "C": "Both are correct"}
    """
    if pd.isna(answer_string) or answer_string == '':
        return None

    options = [opt.strip() for opt in answer_string.split('||')]
    letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
    return {letters[i]: opt for i, opt in enumerate(options) if i < len(letters)}
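# For example:
#   label_answer_options("Han is correct || Elena is correct")
#   -> {"A": "Han is correct", "B": "Elena is correct"}
# Only letters A-J are defined, so any options beyond the tenth are dropped.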
def clean_html_and_normalize(text):
    """
    Remove HTML tags and normalize text for comparison.
    """
    if pd.isna(text):
        return ""
    # Strip HTML tags, collapse whitespace, and normalize spacing around colons
    text = re.sub(r'<[^>]+>', '', str(text))
    text = ' '.join(text.split())
    text = re.sub(r'\s*:\s*', ':', text)
    return text.strip()
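# For example:
#   clean_html_and_normalize("<p>Ratio is  3 : 5</p>")  ->  "Ratio is 3:5"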
def match_student_answer_to_letters(student_answer_text, answer_options_dict):
    """
    Match a student's comma-delimited answers to letter options.

    Args:
        student_answer_text: String like "Answer A text , Answer C text , Answer B text"
        answer_options_dict: Dict like {"A": "Answer A text", "B": "Answer B text", ...}

    Returns:
        String like "A, B, C", or the original text if no match is found
    """
    if pd.isna(student_answer_text) or not answer_options_dict:
        return student_answer_text

    # Multi-select answers are stored with " , " between selections
    student_answers = [ans.strip() for ans in str(student_answer_text).split(' , ')]

    normalized_options = {
        letter: clean_html_and_normalize(text)
        for letter, text in answer_options_dict.items()
    }

    matched_letters = []
    for student_ans in student_answers:
        normalized_student = clean_html_and_normalize(student_ans)

        # Prefer an exact match; fall back to substring matching
        for letter, normalized_option in normalized_options.items():
            if normalized_student == normalized_option:
                matched_letters.append(letter)
                break
        else:
            for letter, normalized_option in normalized_options.items():
                if (normalized_student in normalized_option or
                        normalized_option in normalized_student):
                    matched_letters.append(letter)
                    break

    if matched_letters:
        return ', '.join(sorted(set(matched_letters)))
    return student_answer_text
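# For example, with options {"A": "Han is correct", "B": "Elena is correct"}:
#   match_student_answer_to_letters("Elena is correct , Han is correct", options)
#   -> "A, B"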
def get_correct_option_letters(answer_options, correct_answers):
    """
    Determine which letter(s) correspond to correct answer(s).

    Args:
        answer_options: Dict like {"A": "Han is correct", "B": "Elena is correct", ...}
        correct_answers: String like "Both are correct" or "Han is correct || Elena is correct"

    Returns:
        String like "C" or "A, B" depending on how many options are correct
    """
    if not answer_options or pd.isna(correct_answers):
        return correct_answers

    correct_list = [ans.strip() for ans in correct_answers.split('||')]

    correct_letters = []
    for letter, text in answer_options.items():
        if text in correct_list:
            correct_letters.append(letter)

    return ', '.join(sorted(correct_letters)) if correct_letters else correct_answers
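# For example:
#   get_correct_option_letters({"A": "2", "B": "3", "C": "5"}, "3 || 5")  ->  "B, C"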
def format_answer_options_for_prompt(answer_options):
    """
    Format an answer-options dictionary for display in a prompt.
    Input: {"A": "Han is correct", "B": "Elena is correct", ...}
    Output: "A) Han is correct\nB) Elena is correct\n..."
    """
    if not answer_options:
        return None

    return '\n'.join([f"{letter}) {text}" for letter, text in answer_options.items()])


def create_user_prompt(student_history, new_problem, problem_df):
    """
    Create a user prompt from the student's history and the new problem.

    Args:
        student_history: List of dicts with keys: problem_id, timestamp, problem_text,
                         correct_answer, student_answer, used_hint, saw_answer
        new_problem: Dict with keys: problem_text, correct_answer, used_hint, saw_answer,
                     answer_options (optional)
        problem_df: Unused; kept for interface compatibility.
    """
    prompt = "Task Description:\n\n"
    prompt += "Your task is to model a single student's learning process and predict how they will respond to a new mathematics problem based on their prior work.\n\n"

    prompt += """You will produce three coordinated predictions:

1) **Skill-level knowledge tracing (0 or 1):** Predict whether this student has mastered the underlying skill involved in the new problem.
2) **Question-level knowledge tracing (0 or 1):** Predict whether this student will answer this specific problem correctly.
3) **Cognitive-level prediction (string):** Generate the exact answer the student would most likely produce.
   - For Multiple Choice (select 1): Predict a single letter (e.g., "A" or "B")
   - For Multiple Choice (select all): Predict comma-separated letters (e.g., "A, C" or "B, D")
   - For Fill-in problems: Predict the exact text the student would write
"""

    prompt += """---

Provided Data:

You will receive:
- ProblemID: <id>
- Timestamp: <timestamp>
- Problem: <problem text>
- Problem Type: Multiple Choice (select 1) / Multiple Choice (select all) / Fill-in Problem
- Options: Answer choices in format "A) ...\nB) ...\nC) ..."
- Correct Answer(s): The letter(s) or text of correct answer(s)
- Student's First Answer: Letter(s) or fill-in text
- UsedHint: <True/False>
- SawAnswer: <True/False>
- Skill: <skill_name_or_id>
- A new problem (with optional answer choices), skill metadata, and context flags (`UsedHint`, `SawAnswer`).

# About the context flags:
- **UsedHint = True** → The student viewed or used a hint while solving this problem.
- **SawAnswer = True** → The student saw the correct answer before or during the attempt.
When either of these flags is True, treat the corresponding response as *less reliable evidence of mastery* — it suggests the student needed help and may not have fully learned the concept.
"""

    # Student history, one block per prior problem
    prompt += "**Student's Previous Problems:**\n\n"
    for item in student_history:
        prompt += f"Timestamp: {item['timestamp']}\n"
        prompt += f"Problem: {item['problem_text']}\n"
        prompt += f"Problem Type: {item['problem_type']}\n"
        if item.get('answer_options_formatted'):
            prompt += f"Options:\n{item['answer_options_formatted']}\n"
        prompt += f"Correct Answer: {item['correct_answer']}\n"
        prompt += f"Student's First Answer: {item['student_answer']}\n"
        prompt += f"UsedHint: {item['used_hint']}\n"
        prompt += f"SawAnswer: {item['saw_answer']}\n"
        if item.get('node_name'):
            prompt += f"Skill: {item['node_name']}\n"
        else:
            prompt += "Skill: Undefined\n"
        prompt += "---\n\n"

    # The target problem to predict
    prompt += "**New Problem to Predict:**\n\n"
    prompt += f"Timestamp: {new_problem['timestamp']}\n"
    prompt += f"Problem: {new_problem['problem_text']}\n"
    prompt += f"Problem Type: {new_problem['problem_type']}\n"
    if new_problem.get('answer_options_formatted'):
        prompt += f"Answer Options:\n{new_problem['answer_options_formatted']}\n"
    prompt += f"Correct Answer: {new_problem['correct_answer']}\n"
    if new_problem.get('node_name'):
        prompt += f"Skill: {new_problem['node_name']}\n"
    else:
        prompt += "Skill: Undefined\n"

    return prompt


def extract_json_prediction(response_text):
    """Extract the final JSON prediction from the model's response."""
    # Grab every {...} block (non-greedy) and try to parse the last one,
    # which should be the final prediction object.
    json_matches = re.findall(r'\{[\s\S]*?\}', response_text)

    if json_matches:
        json_str = json_matches[-1]
        try:
            # Best-effort handling of literal escape sequences (e.g., \\n) in the output
            json_str = json_str.encode().decode('unicode_escape')
            json_str = json_str.strip()
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"JSON decode error: {e}")
            print(f"Attempted to parse:\n{json_str}")
        except Exception as e:
            print(f"Error processing JSON: {e}")
    return None
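# For example:
#   extract_json_prediction('...reasoning... {"skill_level": 1, "question_level": 0, "student_answer": "B"}')
#   -> {'skill_level': 1, 'question_level': 0, 'student_answer': 'B'}
# Returns None when no parseable JSON object is found.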
def get_prediction_id(meta):
    """Generate a unique ID for a prediction."""
    return f"{meta['user_id']}_{meta['bin_number']}_{meta['prediction_type']}"
def load_completed_predictions(output_jsonl):
    """Load already-completed prediction IDs from a JSONL file (used to resume runs)."""
    completed = set()
    if os.path.exists(output_jsonl):
        with open(output_jsonl, 'r') as f:
            for line in f:
                if line.strip():
                    result = json.loads(line)
                    completed.add(result['prediction_id'])
        print(f"Loaded {len(completed)} completed predictions from {output_jsonl}")
    return completed


def make_process_single_user(system_prompt):
    """Create a process_single_user function bound to the given system prompt."""
    def process_single_user(args):
        """Process a single user's data and return prompts and metadata."""
        user_id, user_records, min_history, bin_size = args

        prompts = []
        metadata = []

        # Need at least min_history records plus one record to predict on
        if len(user_records) < min_history + 1:
            return prompts, metadata

        num_bins = (len(user_records) - min_history) // bin_size

        # Seed the history with the first min_history records
        student_history = []
        for hist_idx in range(min_history):
            row = user_records[hist_idx]
            student_history.append({
                'problem_id': row['problem_id'],
                'timestamp': row['end_time'],
                'problem_text': row['cleaned body'],
                'correct_answer': row['Fill-in Answers'],
                'answer_options': row['answer_options'] if pd.notna(row['answer_options']) else None,
                'answer_options_formatted': row['answer_options_formatted'] if pd.notna(row.get('answer_options_formatted')) else None,
                'student_answer': row['answer_text'],
                'used_hint': row['hint_count'] > 0,
                'saw_answer': row['saw_answer'],
                'problem_type': row['Problem Type'],
                'node_name': row.get('node_name')
            })

        for bin_idx in range(num_bins):
            # Extend the history with the previous bin's records
            if bin_idx > 0:
                prev_bin_start = min_history + ((bin_idx - 1) * bin_size)
                prev_bin_end = min_history + (bin_idx * bin_size)
                for hist_idx in range(prev_bin_start, prev_bin_end):
                    row = user_records[hist_idx]
                    student_history.append({
                        'problem_id': row['problem_id'],
                        'timestamp': row['end_time'],
                        'problem_text': row['cleaned body'],
                        'correct_answer': row['Fill-in Answers'],
                        'answer_options': row['answer_options'] if pd.notna(row['answer_options']) else None,
                        'answer_options_formatted': row['answer_options_formatted'] if pd.notna(row.get('answer_options_formatted')) else None,
                        'student_answer': row['answer_text'],
                        'used_hint': row['hint_count'] > 0,
                        'saw_answer': row['saw_answer'],
                        'problem_type': row['Problem Type'],
                        'node_name': row.get('node_name')
                    })

            history_end = min_history + (bin_idx * bin_size)
            bin_start = history_end
            bin_end = bin_start + bin_size
            current_bin = user_records[bin_start:bin_end]

            # Find the first correct and first incorrect attempts in this bin
            first_correct_idx = None
            first_incorrect_idx = None

            for idx, row in enumerate(current_bin):
                actual_idx = bin_start + idx
                score = row['discrete_score']

                if score == 1 and first_correct_idx is None:
                    first_correct_idx = actual_idx
                if score == 0 and first_incorrect_idx is None:
                    first_incorrect_idx = actual_idx

                if first_correct_idx is not None and first_incorrect_idx is not None:
                    break

            # Build one prompt for each available target
            for target_idx, prediction_type in [
                (first_correct_idx, 'correct'),
                (first_incorrect_idx, 'incorrect')
            ]:
                if target_idx is None:
                    continue

                target_row = user_records[target_idx]
                new_problem = {
                    'problem_text': target_row['cleaned body'],
                    'correct_answer': target_row['Fill-in Answers'],
                    'answer_options': target_row['answer_options'] if pd.notna(target_row['answer_options']) else None,
                    'answer_options_formatted': target_row['answer_options_formatted'] if pd.notna(target_row.get('answer_options_formatted')) else None,
                    'problem_type': target_row['Problem Type'],
                    'timestamp': target_row['end_time'],
                    'node_name': target_row.get('node_name')
                }

                user_prompt = create_user_prompt(student_history, new_problem, None)
                full_prompt = system_prompt + "\n\n" + user_prompt

                prompts.append(full_prompt)
                metadata.append({
                    'prediction_id': f"{user_id}_{bin_idx}_{prediction_type}",
                    'row_index': target_idx,
                    'user_id': user_id,
                    'history_size': len(student_history),
                    'bin_number': bin_idx,
                    'prediction_type': prediction_type,
                    'id': target_row.get('id_x', None),
                    'problem_id': target_row.get('problem_id', None),
                    'problem_type': target_row['Problem Type'],
                    'actual_answer': target_row['answer_text'],
                    'correct_answer': target_row['Fill-in Answers'],
                    'actual_score': target_row['discrete_score'],
                    'prompt': full_prompt
                })

        return prompts, metadata

    return process_single_user
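# Worked example of the binning scheme above (with the defaults min_history=50,
# bin_size=50): a student with 172 interactions yields (172 - 50) // 50 = 2 bins.
# Bin 0 predicts on records 50-99 using records 0-49 as history; bin 1 predicts on
# records 100-149 using records 0-99; the trailing 22 records are skipped. Within
# each bin, only the first correct and first incorrect attempts become prediction
# targets.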
def append_results_jsonl(results, output_jsonl):
    """Append batch results to a JSONL file."""
    with open(output_jsonl, 'a') as f:
        for result in results:
            f.write(json.dumps(result, cls=NumpyEncoder) + '\n')


def process_batch(batch_metadata, batch_response_texts):
    """Parse a batch of model responses into result records."""
    batch_results = []

    for metadata, response_text in zip(batch_metadata, batch_response_texts):
        prediction = extract_json_prediction(response_text)

        if prediction:
            batch_results.append({
                **metadata,
                'predicted_skill_level': prediction.get('skill_level'),
                'predicted_question_level': prediction.get('question_level'),
                'predicted_student_answer': prediction.get('student_answer'),
                'full_response': response_text
            })
        else:
            # Keep the row with null predictions so parsing failures remain visible
            batch_results.append({
                **metadata,
                'predicted_skill_level': None,
                'predicted_question_level': None,
                'predicted_student_answer': None,
                'full_response': response_text
            })

    return batch_results


# Module-level holder so the multiprocessing workers can pickle and call the
# prompt-building function created in run_inference().
_process_single_user_func = None


def _process_single_user_wrapper(args):
    """Wrapper for multiprocessing that calls the module-level function."""
    return _process_single_user_func(args)


def run_inference(config):
    """
    Main inference function that runs KT prediction with the given model config.

    Args:
        config: Dict with keys:
            - model_id: HuggingFace model ID
            - gen_configs: Dict of generation parameters (passed to SamplingParams)
            - output_prefix: Prefix for the output filename
            - system_prompt_prefix: Optional prefix for the system prompt
              (e.g., "Reasoning: medium\n\n")
    """
    global _process_single_user_func

    model_id = config["model_id"]
    gen_configs = config["gen_configs"]
    output_prefix = config["output_prefix"]

    default_output_jsonl = f"{output_prefix}.jsonl"
    args = parse_args(default_output_jsonl)

    # Resolve the system prompt prefix: the CLI flag overrides the model config
    if args.reasoning_level is not None:
        if args.reasoning_level == "none":
            system_prompt_prefix = ""
        else:
            system_prompt_prefix = f"Reasoning: {args.reasoning_level}\n\n"
    else:
        system_prompt_prefix = config.get("system_prompt_prefix", "")

    system_prompt = system_prompt_prefix + BASE_SYSTEM_PROMPT

    # Bind the system prompt for the multiprocessing workers
    _process_single_user_func = make_process_single_user(system_prompt)

    batch_size = args.batch_size
    data_dir = args.data_dir
    cache_dir = args.cache_dir
    num_students = args.num_students
    bin_size = args.bin_size
    min_history = args.min_history

    # Auto-generated output name encodes the sampling parameters
    n_str = "all" if num_students <= 0 else str(num_students)
    params_suffix = f"_n{n_str}_bin{bin_size}_hist{min_history}"

    if args.output:
        output_jsonl = args.output
    else:
        filename = f"{output_prefix}{params_suffix}.jsonl"
        output_jsonl = os.path.join(args.output_dir, filename)

    student_csv = os.path.join(data_dir, STUDENT_FILE)
    problems_csv = os.path.join(data_dir, PROBLEMS_FILE)
    skill_csv = os.path.join(data_dir, SKILL_FILE)

    print(f"Model: {model_id}")
    print(f"Data directory: {data_dir}")
    print(f"Batch size: {batch_size}")
    print(f"Output JSONL: {output_jsonl}")
    print(f"Num students: {num_students if num_students > 0 else 'all'}")
    print(f"Bin size: {bin_size}")
    print(f"Min history: {min_history}")
    if cache_dir:
        print(f"Model cache: {cache_dir}")
    print(f"Text cleaner: {'legacy (cleantext.py)' if args.legacy_clean else 'default (clean_utils.py)'}")

    # Load and clean the problem and interaction data
    print("\nLoading data...")
    student_df = pd.read_csv(student_csv)
    student_df = student_df.sort_values(['user_id', 'id']).reset_index(drop=True)
    problems_df = pd.read_csv(problems_csv)
    clean_func = clean_text_legacy if args.legacy_clean else clean_problem_body
    problems_df['cleaned body'] = problems_df['Problem Body'].apply(clean_func)

    # Label multiple-choice options as A, B, C, ...
    problems_df['answer_options'] = problems_df['Multiple Choice Options'].apply(label_answer_options)

    # Express correct answers as option letters for multiple-choice problems
    problems_df['correct_answers'] = problems_df.apply(
        lambda row: get_correct_option_letters(row['answer_options'], row['Multiple Choice Answers'])
        if row['Problem Type'] in ['Multiple Choice (select 1)', 'Multiple Choice (select all)']
        else row['Fill-in Answers'],
        axis=1
    )

    skill_df = pd.read_csv(skill_csv)
    problems_df = pd.merge(problems_df, skill_df, on='problem_id', how='left')

    problems_df['answer_options_formatted'] = problems_df['answer_options'].apply(
        lambda x: format_answer_options_for_prompt(x) if pd.notna(x) else None
    )

    # Ensure interactions are ordered by interaction id before merging
    student_df = student_df.sort_values('id').reset_index(drop=True)

    merged_df = student_df.merge(problems_df, on='problem_id', how='inner')

    # Convert recorded multiple-choice answers to option letters
    merged_df['answer_text'] = merged_df.apply(
        lambda row: match_student_answer_to_letters(row['answer_text'], row['answer_options'])
        if row['Problem Type'] in ['Multiple Choice (select 1)', 'Multiple Choice (select all)'] and pd.notna(row['answer_options'])
        else row['answer_text'],
        axis=1
    )

    # Sample students with a fixed seed so repeated runs pick the same users
    all_users = merged_df['user_id'].unique()
    if num_students <= 0:
        selected_users = all_users
        print(f"\nUsing all {len(all_users)} users")
    else:
        np.random.seed(42)
        selected_users = np.random.choice(all_users, size=min(num_students, len(all_users)), replace=False)
        merged_df = merged_df[merged_df['user_id'].isin(selected_users)]
        print(f"\nSelected {len(selected_users)} random users from {len(all_users)} total users")
        print(f"Filtered data: {len(merged_df)} rows")

    # Build prompts for all users in parallel
    print("\nPreparing prompts in parallel...")

    print("Grouping user data...")
    user_groups = [
        (user_id, user_df.to_dict('records'), min_history, bin_size)
        for user_id, user_df in merged_df.groupby('user_id')
    ]
    print(f"Processing {len(user_groups)} users with {cpu_count()} CPU cores...")

    all_prompts = []
    all_metadata = []

    with Pool(processes=cpu_count()) as pool:
        results = list(tqdm(
            pool.imap(_process_single_user_wrapper, user_groups),
            total=len(user_groups),
            desc="Preparing prompts"
        ))

    for prompts, metadata in results:
        all_prompts.extend(prompts)
        all_metadata.extend(metadata)

    print(f"\nTotal predictions to make: {len(all_prompts)}")

    # Skip predictions that already exist in the output file (resume support)
    completed_ids = load_completed_predictions(output_jsonl)
    remaining = [(p, m) for p, m in zip(all_prompts, all_metadata)
                 if m['prediction_id'] not in completed_ids]

    if not remaining:
        print("All predictions already completed!")
        return

    all_prompts, all_metadata = zip(*remaining)
    all_prompts = list(all_prompts)
    all_metadata = list(all_metadata)

    print(f"Already completed: {len(completed_ids)}")
    print(f"Remaining to process: {len(all_prompts)}")
    print(f"Processing in batches of {batch_size}")

    # Initialize the vLLM engine
    print("\nInitializing vLLM engine...")
    sampling_params = SamplingParams(**gen_configs)
    llm_kwargs = {
        "model": model_id,
        "tensor_parallel_size": args.num_gpus,
        "trust_remote_code": True,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "enable_prefix_caching": True,
    }
    if args.max_num_seqs is not None:
        llm_kwargs["max_num_seqs"] = args.max_num_seqs
    if args.max_model_len is not None:
        llm_kwargs["max_model_len"] = args.max_model_len
    if cache_dir:
        llm_kwargs["download_dir"] = cache_dir
    llm = LLM(**llm_kwargs)

    # Generate in batches, saving results after each batch
    results = []
    num_batches = (len(all_prompts) + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        batch_start = batch_idx * batch_size
        batch_end = min(batch_start + batch_size, len(all_prompts))

        batch_prompts = all_prompts[batch_start:batch_end]
        batch_metadata = all_metadata[batch_start:batch_end]

        print(f"\n{'='*80}")
        print(f"Processing batch {batch_idx + 1}/{num_batches}")
        print(f"Items: {batch_start} to {batch_end} ({len(batch_prompts)} prompts)")
        print(f"{'='*80}")

        try:
            outputs = llm.generate(batch_prompts, sampling_params)
            response_texts = [o.outputs[0].text.strip() for o in outputs]

            batch_results = process_batch(batch_metadata, response_texts)
            results.extend(batch_results)

            print(f"Successfully processed batch {batch_idx + 1}")
            print(f"Total results so far: {len(results)}")

            # Persist this batch immediately so an interrupted run can resume
            append_results_jsonl(batch_results, output_jsonl)
            print(f"Saved {len(batch_results)} results to {output_jsonl}")

        except Exception as e:
            print(f"\nERROR processing batch {batch_idx + 1}: {str(e)}")
            print(f"Progress saved in {output_jsonl} - restart to resume")
            raise

    print(f"\n{'='*80}")
    print("All batches processed successfully!")
    print(f"{'='*80}")
    print(f"\nAll results saved to {output_jsonl}")
    print(f"Total predictions processed: {len(results)}")

    # Release GPU memory and tear down the distributed state
    print("\nCleaning up...")
    destroy_model_parallel()
    destroy_distributed_environment()
    del llm
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()

    print("\nDone!")
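# Illustrative model script built on this module (the model ID, sampling values,
# and output prefix below are placeholders, not recommendations):
#
#     from kt_inference_base import run_inference
#
#     MODEL_CONFIG = {
#         "model_id": "Qwen/Qwen2.5-7B-Instruct",
#         "gen_configs": {"temperature": 0.6, "top_p": 0.95, "max_tokens": 4096},
#         "output_prefix": "kt_qwen25_7b",
#         "system_prompt_prefix": "",
#     }
#
#     if __name__ == "__main__":
#         run_inference(MODEL_CONFIG)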