Add files using upload-large-folder tool
- assignment_llm_1/assignment_image/data/cifar-10-batches-py/batches.meta +0 -0
- assignment_llm_1/assignment_text/code/c1.py +685 -0
- assignment_llm_1/assignment_text/code/c1_analysis.py +253 -0
- assignment_llm_1/assignment_text/code/explanation_creation.py +174 -0
- assignment_llm_1/assignment_text/documentation/different_model_size_and_performance.md +77 -0
- assignment_llm_1/assignment_text/documentation/documentation.md +142 -0
- assignment_llm_1/assignment_text/documentation/error_analysis.json +92 -0
- assignment_llm_1/assignment_text/saved_model/Untitled +1 -0
- assignment_llm_1/assignment_text/saved_model/transformer_imdb_experiment_report.md +77 -0
- assignment_llm_1/data/cifar-10-batches-py/batches.meta +0 -0
- assignment_llm_1/data/cifar-10-batches-py/readme.html +1 -0
- code/RL_model/inference_data/RL_model_inference_v1.jsonl +0 -0
- code/RL_model/inference_data/inference_20260213_002423.jsonl +2 -0
- code/RL_model/inference_data/vllm_inference_20260213_003845.jsonl +0 -0
- code/RL_model/inference_data/vllm_inference_20260213_003845_meta.json +10 -0
- code/RL_model/inference_data/vllm_inference_20260213_165923.jsonl +0 -0
- code/RL_model/inference_data/vllm_inference_20260213_170937_meta.json +10 -0
- code/RL_model/inference_data/vllm_inference_qwen-qwen3-4b-instruct-2507_20260213_173334_meta.json +10 -0
- code/RL_model/unsloth_rl/RL_code.py +165 -0
- code/RL_model/unsloth_rl/RL_training.ipynb +475 -0
- code/RL_model/unsloth_rl/claim_verifier.py +175 -0
- code/RL_model/unsloth_rl/finetune.py +91 -0
- code/RL_model/unsloth_rl/health_classifier.py +42 -0
- code/RL_model/unsloth_rl/highlighter.py +103 -0
- code/RL_model/unsloth_rl/inference.py +120 -0
- code/RL_model/unsloth_rl/prompt +58 -0
- code/RL_model/unsloth_rl/reward_mock.py +127 -0
- code/RL_model/unsloth_rl/test_reward_mock_unittest.py +139 -0
- code/RL_model/unsloth_rl/testing.py +215 -0
- code/RL_model/unsloth_rl/testing_v2.py +138 -0
- code/RL_model/verl/Search-R1/.gitignore +122 -0
- code/RL_model/verl/Search-R1/LICENSE +202 -0
- code/RL_model/verl/Search-R1/Notice.txt +1 -0
- code/RL_model/verl/Search-R1/README.md +275 -0
- code/RL_model/verl/Search-R1/VERL_README.md +103 -0
- code/RL_model/verl/Search-R1/infer.py +128 -0
- code/RL_model/verl/Search-R1/llm_guard_3B_10k_v2.log +180 -0
- code/RL_model/verl/Search-R1/pyproject.toml +78 -0
- code/RL_model/verl/Search-R1/requirements.txt +16 -0
- code/RL_model/verl/Search-R1/retrieval_launch.sh +13 -0
- code/RL_model/verl/Search-R1/setup.py +54 -0
- code/RL_model/verl/Search-R1/train_grpo.sh +46 -0
- code/RL_model/verl/Search-R1/train_ppo.sh +90 -0
- code/RL_model/verl/verl_train/.git-blame-ignore-revs +13 -0
- code/RL_model/verl/verl_train/.gitignore +130 -0
- code/RL_model/verl/verl_train/.gitmodules +3 -0
- code/RL_model/verl/verl_train/.log +0 -0
- code/RL_model/verl/verl_train/.pre-commit-config.yaml +45 -0
- code/RL_model/verl/verl_train/.readthedocs.yaml +19 -0
- code/RL_model/verl/verl_train/CONTRIBUTING.md +90 -0
assignment_llm_1/assignment_image/data/cifar-10-batches-py/batches.meta
ADDED
Binary file (158 Bytes).
assignment_llm_1/assignment_text/code/c1.py
ADDED
@@ -0,0 +1,685 @@
import os
from datetime import datetime

# Configure CUDA visibility (set this as appropriate for your environment).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import math
import random
import re

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import Counter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

"""
Homework 1 (Part I) – Transformer-based sentiment analysis on the IMDB dataset.

This script implements:
- Data loading and preprocessing for the IMDB movie review dataset
- A Transformer-based text classification model
- Training and evaluation loops for binary sentiment analysis
- Saving of the trained model together with vocabulary and configuration

The code is organized into clearly separated sections:
1) Data preparation and tokenization
2) Transformer components (building blocks)
3) Full Transformer classifier
4) Training and evaluation logic
5) Execution example using a train/validation split of IMDB

Model Analysis and Improvement:
1. After evaluation, delve into analyzing your model's behavior to identify
   areas for improvement and fine-tuning.
2. Analyze translation errors (if applicable): Examine specific translation
   examples where the model performs poorly and try to understand the reasons
   behind these errors. Are there issues with handling rare words or
   idiomatic expressions?
3. Explore the impact of model size: Experiment with different Transformer
   model sizes (e.g., small, medium, large) to understand how model
   complexity affects performance.
"""

# ==========================================
# 1. Data Preparation & Tokenization
# ==========================================

def tokenize(text):
    """
    Tokenize a raw review string into a list of normalized word tokens.

    Steps:
    - Convert to lowercase
    - Remove HTML line breaks
    - Remove non-alphanumeric characters (except whitespace)
    - Split on whitespace

    Args:
        text (str): Raw review text.

    Returns:
        List[str]: List of token strings.
    """
    text = text.lower()
    text = re.sub(r"<br />", " ", text)  # Remove HTML line breaks
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
    return text.split()

class IMDBDataset(Dataset):
    """
    Torch Dataset wrapper for IMDB sequences and labels.

    Each item corresponds to:
    - a fixed-length sequence of token IDs
    - a sentiment label (0 = negative, 1 = positive)
    """
    def __init__(self, sequences, labels):
        self.sequences = torch.tensor(sequences, dtype=torch.long)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

def build_vocab(texts, max_vocab_size=10000):
    """
    Build a word-to-index vocabulary from a collection of texts.

    The vocabulary is constructed using token frequency counts from the
    training set only to avoid information leakage. Two special tokens
    are always included:
    - "<PAD>" mapped to index 0
    - "<UNK>" mapped to index 1

    The remaining (max_vocab_size - 2) most frequent tokens are added.

    Args:
        texts (Iterable[str]): Training texts.
        max_vocab_size (int): Maximum size of the vocabulary.

    Returns:
        Dict[str, int]: Mapping from token string to integer index.
    """
    counter = Counter()
    for text in texts:
        counter.update(tokenize(text))

    # Reserve 0 for padding and 1 for unknown tokens
    vocab = {"<PAD>": 0, "<UNK>": 1}
    common_words = counter.most_common(max_vocab_size - 2)
    for word, _ in common_words:
        vocab[word] = len(vocab)
    return vocab

def preprocess_data(texts, vocab, max_len=128):
    """
    Convert raw texts into padded/truncated sequences of token IDs.

    Steps:
    - Tokenize each text
    - Map tokens to vocabulary indices (using <UNK> for OOV tokens)
    - Truncate to max_len or pad with <PAD> to reach max_len

    Args:
        texts (Iterable[str]): Input texts (reviews).
        vocab (Dict[str, int]): Token-to-index mapping.
        max_len (int): Maximum sequence length in tokens.

    Returns:
        np.ndarray: Array of shape (num_examples, max_len) with dtype int.
    """
    sequences = []
    for text in texts:
        tokens = tokenize(text)
        token_ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
        # Pad or Truncate
        if len(token_ids) < max_len:
            token_ids += [vocab["<PAD>"]] * (max_len - len(token_ids))
        else:
            token_ids = token_ids[:max_len]
        sequences.append(token_ids)
    return np.array(sequences)

# ==========================================
# 2. Transformer Components
# ==========================================

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding module.

    Implements the deterministic positional encoding from the original
    Transformer paper ("Attention is All You Need"), which is added to
    token embeddings to inject information about token positions.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Add positional encodings to input embeddings.

        Args:
            x (Tensor): Input tensor of shape [batch_size, seq_len, d_model].

        Returns:
            Tensor: Positionally encoded representations with same shape as x.
        """
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    For each token, attention is computed over all tokens in the sequence
    (including itself) using multiple attention heads. Each head operates
    in its own subspace and the outputs are concatenated.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Apply multi-head self-attention to the input sequence.

        Args:
            x (Tensor): Input tensor of shape [batch_size, seq_len, d_model].
            mask (Tensor, optional): Attention mask of shape
                [batch_size, 1, 1, seq_len] or broadcastable equivalent,
                where positions with 0 are masked out.

        Returns:
            Tensor: Output tensor of shape [batch_size, seq_len, d_model].
        """
        batch_size, seq_len, _ = x.shape

        # Linear projections
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)

class TransformerEncoderBlock(nn.Module):
    """
    Single Transformer encoder block consisting of:
    - multi-head self-attention sublayer (with residual + layer norm)
    - position-wise feed-forward sublayer (with residual + layer norm)
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Forward pass through one encoder block.

        Args:
            x (Tensor): Input tensor of shape [batch_size, seq_len, d_model].
            mask (Tensor, optional): Attention mask (see MultiHeadAttention).

        Returns:
            Tensor: Output tensor of shape [batch_size, seq_len, d_model].
        """
        # Sublayer 1: self-attention with residual connection
        attn_out = self.mha(x, mask)
        x = self.layernorm1(x + self.dropout(attn_out))
        # Sublayer 2: position-wise feed-forward network with residual
        ffn_out = self.ffn(x)
        x = self.layernorm2(x + self.dropout(ffn_out))
        return x

# ==========================================
# 3. Full Transformer Classifier
# ==========================================

class TransformerClassifier(nn.Module):
    """
    Transformer-based text classifier for IMDB sentiment analysis.

    Architecture:
    - Token embedding layer
    - Sinusoidal positional encoding
    - Stack of Transformer encoder blocks
    - Global average pooling over sequence dimension
    - Linear classification head to predict sentiment label
    """
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, num_classes=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        # Classification Head: Flatten or Global Pool
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x, mask=None):
        """
        Forward pass for the classifier.

        Args:
            x (Tensor): Input tensor of token IDs
                with shape [batch_size, seq_len].
            mask (Tensor, optional): Attention mask (not used in this script).

        Returns:
            Tensor: Logits of shape [batch_size, num_classes].
        """
        x = self.dropout(self.pos_encoding(self.embedding(x)))

        for layer in self.encoder_layers:
            x = layer(x, mask)

        # Global Average Pooling across the sequence dimension
        x = x.mean(dim=1)
        return self.classifier(x)

# ==========================================
# 4. Training and Evaluation Logic
# ==========================================

def train_model(model, train_loader, val_loader, epochs, lr, device):
    """
    Train the Transformer classifier on the IMDB training split.

    Args:
        model (nn.Module): TransformerClassifier instance.
        train_loader (DataLoader): Batches of (sequence, label) for training.
        val_loader (DataLoader): Batches for validation.
        epochs (int): Number of full passes through the training set.
        lr (float): Initial learning rate for Adam optimizer.
        device (torch.device): Device on which to run training.

    Uses:
    - CrossEntropyLoss for binary sentiment classification.
    - Adam optimizer with StepLR scheduler (gamma=0.5 every 2 epochs).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_seq, batch_lab in train_loader:
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)

            optimizer.zero_grad()
            outputs = model(batch_seq)
            loss = criterion(outputs, batch_lab)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        val_metrics = evaluate_model(model, val_loader, device)
        val_acc = val_metrics["accuracy"]
        val_p = val_metrics["precision"]
        val_r = val_metrics["recall"]
        val_f1 = val_metrics["f1"]
        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"Loss: {total_loss/len(train_loader):.4f} | "
            f"Val Acc: {val_acc:.4f} | "
            f"Val P: {val_p:.4f} | Val R: {val_r:.4f} | Val F1: {val_f1:.4f}"
        )

def evaluate_model(model, loader, device):
    """
    Evaluate the model on a dataset.

    Args:
        model (nn.Module): Trained (or partially trained) classifier.
        loader (DataLoader): DataLoader for validation or test data.
        device (torch.device): Device on which to perform evaluation.

    Returns:
        Dict[str, float]: Dictionary with accuracy, precision, recall, and F1.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            outputs = model(batch_seq)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_lab.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    p, r, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
    return {"accuracy": acc, "precision": p, "recall": r, "f1": f1}

def count_trainable_parameters(model):
    """
    Count the number of trainable parameters in a model.

    Args:
        model (nn.Module): Model whose parameters should be counted.

    Returns:
        int: Number of trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def write_experiment_report_md(
    report_path,
    results,
    best_result,
    device,
    train_size,
    val_size,
):
    """
    Write a Markdown report summarizing model-size experiment results.

    Args:
        report_path (str): Output Markdown file path.
        results (List[Dict]): Per-model experiment outputs.
        best_result (Dict): Best-performing entry from `results`.
        device (torch.device): Device used during training.
        train_size (int): Number of training samples.
        val_size (int): Number of validation samples.
    """
    lines = []
    lines.append("# IMDB Transformer Model-Size Experiment Report")
    lines.append("")
    lines.append(f"- Generated at: `{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}`")
    lines.append(f"- Device: `{device}`")
    lines.append(f"- Training samples: `{train_size}`")
    lines.append(f"- Validation samples: `{val_size}`")
    lines.append(f"- Max vocab size: `{MAX_VOCAB}`")
    lines.append(f"- Max sequence length: `{MAX_LEN}`")
    lines.append(f"- Batch size: `{BATCH_SIZE}`")
    lines.append(f"- Epochs: `{EPOCHS}`")
    lines.append(f"- Learning rate: `{LR}`")
    lines.append("")

    lines.append("## Overall Comparison")
    lines.append("")
    lines.append("| Model Size | Trainable Params | Accuracy | Precision | Recall | F1 | Checkpoint |")
    lines.append("|---|---:|---:|---:|---:|---:|---|")
    for item in results:
        metrics = item["metrics"]
        lines.append(
            f"| {item['size']} | {item['params']:,} | "
            f"{metrics['accuracy']:.4f} | {metrics['precision']:.4f} | "
            f"{metrics['recall']:.4f} | {metrics['f1']:.4f} | "
            f"`{item['checkpoint_path']}` |"
        )
    lines.append("")

    lines.append("## Best Model")
    lines.append("")
    lines.append(f"- Best size by validation F1: `{best_result['size']}`")
    lines.append(f"- Checkpoint: `{best_result['checkpoint_path']}`")
    lines.append(f"- Trainable parameters: `{best_result['params']:,}`")
    lines.append("- Metrics:")
    lines.append(f"  - Accuracy: `{best_result['metrics']['accuracy']:.4f}`")
    lines.append(f"  - Precision: `{best_result['metrics']['precision']:.4f}`")
    lines.append(f"  - Recall: `{best_result['metrics']['recall']:.4f}`")
    lines.append(f"  - F1: `{best_result['metrics']['f1']:.4f}`")
    lines.append("")

    lines.append("## Per-Model Details")
    lines.append("")
    for item in results:
        cfg = item["config"]
        metrics = item["metrics"]
        lines.append(f"### {item['size'].capitalize()} model")
        lines.append("")
        lines.append("- Architecture:")
        lines.append(f"  - `d_model`: `{cfg['d_model']}`")
        lines.append(f"  - `num_heads`: `{cfg['num_heads']}`")
        lines.append(f"  - `num_layers`: `{cfg['num_layers']}`")
        lines.append(f"  - `d_ff`: `{cfg['d_ff']}`")
        lines.append(f"- Trainable params: `{item['params']:,}`")
        lines.append(f"- Checkpoint: `{item['checkpoint_path']}`")
        lines.append("- Validation metrics:")
        lines.append(f"  - Accuracy: `{metrics['accuracy']:.4f}`")
        lines.append(f"  - Precision: `{metrics['precision']:.4f}`")
        lines.append(f"  - Recall: `{metrics['recall']:.4f}`")
        lines.append(f"  - F1: `{metrics['f1']:.4f}`")
        lines.append("")

    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

# ==========================================
# 5. Execution Example (Subset of IMDB)
# ==========================================

# Dataset loading using the real IMDB dataset via HuggingFace datasets.
# Data source:
#   HuggingFace Datasets – "imdb" configuration, which originates from the
#   Large Movie Review Dataset (Maas et al., 2011).
def load_imdb_texts(split: str = "train"):
    """
    Load IMDB dataset texts and labels using `datasets.load_dataset`.

    Args:
        split (str): Dataset split, e.g. "train" or "test".

    Returns:
        Tuple[List[str], List[int]]: List of review texts and sentiment labels,
        where labels are integers 0 (negative) and 1 (positive).
    """
    ds = load_dataset("imdb", split=split)
    texts = ds["text"]
    labels = ds["label"]
    return texts, labels

# ===========================
# Hyperparameters
# ===========================
# MAX_VOCAB: upper bound on vocabulary size. Larger values can capture more
# rare words but increase model size and memory usage.
MAX_VOCAB = 5000
# MAX_LEN: maximum number of tokens per review. Longer sequences capture
# more context but are more expensive to process; here we use 64 for speed.
MAX_LEN = 64
# BATCH_SIZE: number of examples per optimization step. Larger batches yield
# smoother gradients but require more memory.
BATCH_SIZE = 32
# EPOCHS: number of full passes through the training dataset.
EPOCHS = 5
# LR: initial learning rate for the Adam optimizer.
LR = 0.001

# Transformer size presets for model-complexity experiments.
# Each preset controls hidden size, attention heads, number of layers,
# and feed-forward dimension.
MODEL_SIZES = {
    "small": {"d_model": 64, "num_heads": 4, "num_layers": 1, "d_ff": 128},
    "medium": {"d_model": 128, "num_heads": 8, "num_layers": 2, "d_ff": 256},
    "large": {"d_model": 256, "num_heads": 8, "num_layers": 4, "d_ff": 512},
}

# Directory to save trained model and related artifacts (checkpoint, vocab,
# and configuration dictionary for reproducibility).
# Keep output paths relative to the current working directory.
SAVE_DIR = os.path.join(".", "saved_model")
os.makedirs(SAVE_DIR, exist_ok=True)
MODEL_PATH = os.path.join(SAVE_DIR, "transformer_imdb.pt")
REPORT_PATH = os.path.join(SAVE_DIR, "transformer_imdb_experiment_report.md")

def main():
    """
    Train a Transformer-based sentiment classifier on IMDB and save the model,
    vocabulary, and configuration to disk.
    """
    # 1) Load IMDB training split and then create train/validation split.
    all_train_texts, all_train_labels = load_imdb_texts(split="train")

    train_texts, val_texts, train_labels, val_labels = train_test_split(
        all_train_texts,
        all_train_labels,
        test_size=0.2,
        random_state=42,
        stratify=all_train_labels,
    )

    # 2) Build vocabulary using training texts only (avoid validation leakage).
    vocab = build_vocab(train_texts, MAX_VOCAB)

    # 3) Preprocess train and validation data into fixed-length ID sequences.
    train_sequences = preprocess_data(train_texts, vocab, MAX_LEN)
    val_sequences = preprocess_data(val_texts, vocab, MAX_LEN)

    train_dataset = IMDBDataset(train_sequences, train_labels)
    val_dataset = IMDBDataset(val_sequences, val_labels)

    # DataLoaders for mini-batch training and validation.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = []

    # Train and evaluate multiple model sizes to analyze how complexity
    # changes sentiment-classification performance.
    for size_name, size_cfg in MODEL_SIZES.items():
        print("\n" + "=" * 72)
        print(f"Training {size_name.upper()} model with config: {size_cfg}")
        print("=" * 72)

        model = TransformerClassifier(
            len(vocab),
            size_cfg["d_model"],
            size_cfg["num_heads"],
            size_cfg["num_layers"],
            size_cfg["d_ff"],
            MAX_LEN,
        )
        param_count = count_trainable_parameters(model)
        print(f"Trainable parameters ({size_name}): {param_count:,}")

        train_model(model, train_loader, val_loader, EPOCHS, LR, device)
|
| 611 |
+
val_metrics = evaluate_model(model, val_loader, device)
|
| 612 |
+
size_model_path = os.path.join(SAVE_DIR, f"transformer_imdb_{size_name}.pt")
|
| 613 |
+
results.append(
|
| 614 |
+
{
|
| 615 |
+
"size": size_name,
|
| 616 |
+
"params": param_count,
|
| 617 |
+
"config": size_cfg,
|
| 618 |
+
"metrics": val_metrics,
|
| 619 |
+
"checkpoint_path": size_model_path,
|
| 620 |
+
}
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
# Save each trained size-specific model.
|
| 624 |
+
torch.save(
|
| 625 |
+
{
|
| 626 |
+
"model_state_dict": model.state_dict(),
|
| 627 |
+
"vocab": vocab,
|
| 628 |
+
"config": {
|
| 629 |
+
"max_vocab": MAX_VOCAB,
|
| 630 |
+
"max_len": MAX_LEN,
|
| 631 |
+
"batch_size": BATCH_SIZE,
|
| 632 |
+
"epochs": EPOCHS,
|
| 633 |
+
"lr": LR,
|
| 634 |
+
"size_name": size_name,
|
| 635 |
+
**size_cfg,
|
| 636 |
+
},
|
| 637 |
+
"val_metrics": val_metrics,
|
| 638 |
+
},
|
| 639 |
+
size_model_path,
|
| 640 |
+
)
|
| 641 |
+
print(f"Saved {size_name} model to {size_model_path}")
|
| 642 |
+
|
| 643 |
+
# Print a concise comparison table at the end.
|
| 644 |
+
print("\n" + "#" * 72)
|
| 645 |
+
print("Model Size Impact Summary (Validation Set)")
|
| 646 |
+
print("#" * 72)
|
| 647 |
+
print(f"{'Size':<10} {'Params':>12} {'Acc':>8} {'Precision':>10} {'Recall':>8} {'F1':>8}")
|
| 648 |
+
for item in results:
|
| 649 |
+
m = item["metrics"]
|
| 650 |
+
print(
|
| 651 |
+
f"{item['size']:<10} "
|
| 652 |
+
f"{item['params']:>12,} "
|
| 653 |
+
f"{m['accuracy']:>8.4f} "
|
| 654 |
+
f"{m['precision']:>10.4f} "
|
| 655 |
+
f"{m['recall']:>8.4f} "
|
| 656 |
+
f"{m['f1']:>8.4f}"
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
# Keep a compatibility checkpoint name for the best model by validation F1.
|
| 660 |
+
best_result = max(results, key=lambda x: x["metrics"]["f1"])
|
| 661 |
+
best_model_path = os.path.join(SAVE_DIR, f"transformer_imdb_{best_result['size']}.pt")
|
| 662 |
+
torch.save(
|
| 663 |
+
{
|
| 664 |
+
"best_size": best_result["size"],
|
| 665 |
+
"best_model_path": best_model_path,
|
| 666 |
+
"all_results": results,
|
| 667 |
+
},
|
| 668 |
+
MODEL_PATH,
|
| 669 |
+
)
|
| 670 |
+
print(f"\nBest model by Val F1: {best_result['size']} -> {best_model_path}")
|
| 671 |
+
print(f"Experiment summary saved to {MODEL_PATH}")
|
| 672 |
+
|
| 673 |
+
write_experiment_report_md(
|
| 674 |
+
REPORT_PATH,
|
| 675 |
+
results,
|
| 676 |
+
best_result,
|
| 677 |
+
device,
|
| 678 |
+
train_size=len(train_texts),
|
| 679 |
+
val_size=len(val_texts),
|
| 680 |
+
)
|
| 681 |
+
print(f"Markdown report saved to {REPORT_PATH}")
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
if __name__ == "__main__":
|
| 685 |
+
main()
|
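To illustrate the best-by-validation-F1 selection performed at the end of `main()`, here is a minimal, hypothetical mock: the `results` entries mirror the shape built in the loop above, and the metric values are taken from the experiment report in this repo. This sketch is illustrative and not part of `c1.py`.

```python
# Mock `results` entries in the shape built by main(); metric values are the
# small/medium numbers from the experiment report.
results = [
    {"size": "small", "params": 353_602,
     "metrics": {"accuracy": 0.7816, "precision": 0.7832, "recall": 0.7788, "f1": 0.7810}},
    {"size": "medium", "params": 905_218,
     "metrics": {"accuracy": 0.7874, "precision": 0.7948, "recall": 0.7748, "f1": 0.7847}},
]

# Same selection rule as in main(): pick the entry with the highest validation F1.
best_result = max(results, key=lambda x: x["metrics"]["f1"])
print(best_result["size"])  # → medium
```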
assignment_llm_1/assignment_text/code/c1_analysis.py
ADDED
@@ -0,0 +1,253 @@
"""
Evaluation and qualitative error analysis helpers for the IMDB Transformer model.

This module is separate from `c1.py` and focuses only on:
- Loading a previously trained model from disk.
- Evaluating it on an IMDB split.
- Inspecting misclassified examples for qualitative error analysis.
"""

from typing import Dict, List, Tuple

import argparse
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from c1 import (
    IMDBDataset,
    TransformerClassifier,
    preprocess_data,
    evaluate_model,
    load_imdb_texts,
)

# Keep output/checkpoint paths relative to the current working directory.
SAVE_DIR = os.path.join(".", "saved_model")
MODEL_PATH = os.path.join(SAVE_DIR, "transformer_imdb.pt")


def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 5,
) -> None:
    """
    Inspect concrete examples where the model makes mistakes to understand
    *why* it fails and how to improve it.

    How to read the output (practical guidance):
    - Start with the true vs. predicted label:
      - For each misclassified review, ask whether the ground-truth label
        actually matches the human-intuitive sentiment. Occasional noisy
        labels are common in IMDB-style datasets.
    - Look at the confidence vector:
      - Very confident but wrong predictions often indicate *systematic bias*
        (e.g., the model over-trusts certain keywords like "great", "worst").
      - Low-confidence errors may simply reflect inherently ambiguous reviews.
    - Scan the text content:
      - Check for **rare or domain-specific words** (brand names, slang,
        technical jargon) that might not appear often enough in training.
      - Look for **negation patterns** ("not good", "hardly bad", "no longer
        terrible") where bag-of-words style cues can mislead attention.
      - Notice **mixed sentiment** or **topic vs. opinion** separation
        (e.g., long plot summary plus a brief opinion at the end).
      - Pay attention to **sarcasm and irony**, which are notoriously hard
        for models relying mostly on local lexical cues.
    - Compare several misclassified examples:
      - If you see many errors with long reviews, consider increasing MAX_LEN
        or using a deeper model.
      - If errors cluster around subtle, low-intensity sentiment, you may need
        more expressive capacity (higher d_model / more layers) or additional
        training data.

    Based on these observations you can propose targeted improvements, such as:
    - Expanding the vocabulary or switching to subword tokenization.
    - Adjusting hyperparameters (sequence length, model size).
    - Incorporating pre-trained language models for richer semantics.
    """
    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    printed = 0
    with torch.no_grad():
        for batch_idx, (batch_seq, batch_lab) in enumerate(loader):
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            start = batch_idx * loader.batch_size
            end = start + batch_seq.size(0)
            batch_texts = texts[start:end]

            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.cpu().numpy(),
                preds.cpu().numpy(),
                probs.cpu().numpy(),
            ):
                if true_y != pred_y:
                    printed += 1
                    print("=" * 80)
                    print(f"Misclassified example #{printed}")
                    print(f"True label : {true_y} (0=neg, 1=pos)")
                    print(f"Predicted label: {pred_y}")
                    print(f"Model confidence (class 0, class 1): {prob_vec}")
                    # Show the (truncated) review text so the guidance in the
                    # docstring above can actually be applied.
                    print(f"Review text : {text[:500]}")

                    if printed >= num_examples:
                        print("=" * 80)
                        print(
                            f"Displayed the first {num_examples} misclassified "
                            "examples on this split."
                        )
                        return

    if printed == 0:
        print("No misclassified examples found on this split (perfect accuracy).")


def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Load a previously trained Transformer model, along with its vocabulary
    and configuration, from the checkpoint saved by `c1.py`.

    Returns:
        model: Loaded TransformerClassifier on the requested device.
        vocab: Token-to-index mapping used during training.
        config: Hyperparameter/config dictionary saved in the checkpoint.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ckpt = torch.load(checkpoint_path, map_location=device)
    vocab: Dict[str, int] = ckpt["vocab"]
    config: Dict = ckpt["config"]

    model = TransformerClassifier(
        vocab_size=len(vocab),
        d_model=config["d_model"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        d_ff=config["d_ff"],
        max_len=config["max_len"],
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    return model, vocab, config


def evaluate_and_analyze_saved_model(
    split: str = "test",
    checkpoint_path: str | None = None,
    model_size: str = "medium",
    num_examples: int = 5,
    device: torch.device | None = None,
) -> None:
    """
    High-level helper that:
    1) Loads the trained model/vocab/config from disk.
    2) Evaluates it on the requested IMDB split.
    3) Runs qualitative error analysis on that split.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if checkpoint_path is None:
        checkpoint_path = os.path.join(SAVE_DIR, f"transformer_imdb_{model_size}.pt")

    print(f"Loading trained model from: {checkpoint_path}")
    model, vocab, config = load_trained_model_from_checkpoint(
        checkpoint_path=checkpoint_path,
        device=device,
    )

    print(f"Evaluating on IMDB '{split}' split...")
    texts, labels = load_imdb_texts(split=split)
    sequences = preprocess_data(texts, vocab, config["max_len"])
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=False)

    metrics = evaluate_model(model, loader, device)
    print("Evaluation metrics:", metrics)

    print("\nRunning qualitative error analysis...")
    analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=num_examples,
    )


def main():
    """
    Command-line interface for evaluation and analysis utilities.

    Example:
        # Evaluate medium model on IMDB test split and show 5 errors
        python c1_analysis.py --split test --model_size medium --num_examples 5
    """
    parser = argparse.ArgumentParser(description="IMDB Transformer evaluation and analysis utilities")
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="IMDB split to evaluate on (e.g., 'test', 'train').",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=(
            "Optional explicit checkpoint path. If provided, this overrides "
            "--model_size."
        ),
    )
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["small", "medium", "large"],
        default="medium",
        help=(
            "Model size to load from saved checkpoints. Used when --checkpoint "
            "is not provided."
        ),
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=5,
        help="Number of misclassified examples to print in error analysis.",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    evaluate_and_analyze_saved_model(
        split=args.split,
        checkpoint_path=args.checkpoint,
        model_size=args.model_size,
        num_examples=args.num_examples,
        device=device,
    )


if __name__ == "__main__":
    main()
assignment_llm_1/assignment_text/code/explanation_creation.py
ADDED
@@ -0,0 +1,174 @@
import argparse
import json
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple
from torch.utils.data import DataLoader

# Assuming these are in your c1.py
from c1 import (
    IMDBDataset,
    TransformerClassifier,
    preprocess_data,
    evaluate_model,
    load_imdb_texts,
    MODEL_PATH,
)

# You would need to install openai: pip install openai
from openai import OpenAI

# Load the OpenAI API key from a local JSON file and initialize the client.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]

client = OpenAI(api_key=openai_api_key)


def get_llm_explanation(review_text: str, true_y: int, pred_y: int) -> str:
    """
    Uses an LLM to perform qualitative reasoning on why the model failed.
    """
    sentiment = {0: "Negative", 1: "Positive"}

    prompt = f"""
A Transformer model misclassified the following movie review.

REVIEW: "{review_text[:1000]}"
TRUE LABEL: {sentiment[true_y]}
MODEL PREDICTED: {sentiment[pred_y]}

Task: Provide a concise (2-3 sentence) explanation of why a machine learning
model might have struggled with this specific text. Mention linguistic
features like sarcasm, double negatives, mixed sentiment, or specific keywords.
"""

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",  # Using 4o-mini as a high-performance proxy for "mini" models
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"LLM Analysis failed: {str(e)}"


def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 10,
) -> List[Dict]:
    """
    Identifies errors, generates LLM explanations, and returns structured results.
    """
    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    error_results = []
    printed = 0

    with torch.no_grad():
        for batch_idx, (batch_seq, batch_lab) in enumerate(loader):
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            start = batch_idx * loader.batch_size
            batch_texts = texts[start:start + batch_seq.size(0)]

            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.cpu().numpy(),
                preds.cpu().numpy(),
                probs.cpu().numpy(),
            ):
                if true_y != pred_y:
                    printed += 1

                    print(f"Analyzing error #{printed} with LLM...")
                    explanation = get_llm_explanation(text, true_y, pred_y)

                    error_entry = {
                        "example_id": printed,
                        "true_label": int(true_y),
                        "predicted_label": int(pred_y),
                        "confidence_neg": float(prob_vec[0]),
                        "confidence_pos": float(prob_vec[1]),
                        "text": text,
                        "explanation": explanation,
                    }
                    error_results.append(error_entry)

                    # Print to console for immediate feedback
                    print("=" * 80)
                    print(f"True: {true_y} | Pred: {pred_y}")
                    print(f"Reasoning: {explanation}")
                    print("=" * 80)

                    if printed >= num_examples:
                        return error_results

    return error_results


def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    # NOTE: c1.py saves only a summary dict at MODEL_PATH; pass a size-specific
    # checkpoint (saved_model/transformer_imdb_<size>.pt) that actually contains
    # "model_state_dict", "vocab", and "config".
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ckpt = torch.load(checkpoint_path, map_location=device)
    vocab = ckpt["vocab"]
    config = ckpt["config"]

    model = TransformerClassifier(
        vocab_size=len(vocab),
        d_model=config["d_model"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        d_ff=config["d_ff"],
        max_len=config["max_len"],
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    return model, vocab, config


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--num_examples", type=int, default=10)
    parser.add_argument("--output", type=str, default="error_analysis.json")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    model, vocab, config = load_trained_model_from_checkpoint(device=device)

    # 2. Load Data
    texts, labels = load_imdb_texts(split=args.split)

    # 3. Analyze
    errors = analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=args.num_examples,
    )

    # 4. Save Results
    with open(args.output, "w") as f:
        json.dump(errors, f, indent=4)
    print(f"\nAnalysis complete. Results saved to {args.output}")


if __name__ == "__main__":
    main()
assignment_llm_1/assignment_text/documentation/different_model_size_and_performance.md
ADDED
@@ -0,0 +1,77 @@
# IMDB Transformer Model-Size Experiment Report

- Generated at: `2026-02-10 09:52:58`
- Device: `cuda`
- Training samples: `20000`
- Validation samples: `5000`
- Max vocab size: `5000`
- Max sequence length: `64`
- Batch size: `32`
- Epochs: `5`
- Learning rate: `0.001`

## Overall Comparison

| Model Size | Trainable Params | Accuracy | Precision | Recall | F1 | Checkpoint |
|---|---:|---:|---:|---:|---:|---|
| small | 353,602 | 0.7816 | 0.7832 | 0.7788 | 0.7810 | `assignment_llm_1/saved_model/transformer_imdb_small.pt` |
| medium | 905,218 | 0.7874 | 0.7948 | 0.7748 | 0.7847 | `assignment_llm_1/saved_model/transformer_imdb_medium.pt` |
| large | 3,388,930 | 0.7374 | 0.7392 | 0.7336 | 0.7364 | `assignment_llm_1/saved_model/transformer_imdb_large.pt` |

## Best Model

- Best size by validation F1: `medium`
- Checkpoint: `assignment_llm_1/saved_model/transformer_imdb_medium.pt`
- Trainable parameters: `905,218`
- Metrics:
  - Accuracy: `0.7874`
  - Precision: `0.7948`
  - Recall: `0.7748`
  - F1: `0.7847`

## Per-Model Details

### Small model

- Architecture:
  - `d_model`: `64`
  - `num_heads`: `4`
  - `num_layers`: `1`
  - `d_ff`: `128`
- Trainable params: `353,602`
- Checkpoint: `assignment_llm_1/saved_model/transformer_imdb_small.pt`
- Validation metrics:
  - Accuracy: `0.7816`
  - Precision: `0.7832`
  - Recall: `0.7788`
  - F1: `0.7810`

### Medium model

- Architecture:
  - `d_model`: `128`
  - `num_heads`: `8`
  - `num_layers`: `2`
  - `d_ff`: `256`
- Trainable params: `905,218`
- Checkpoint: `assignment_llm_1/saved_model/transformer_imdb_medium.pt`
- Validation metrics:
  - Accuracy: `0.7874`
  - Precision: `0.7948`
  - Recall: `0.7748`
  - F1: `0.7847`

### Large model

- Architecture:
  - `d_model`: `256`
  - `num_heads`: `8`
  - `num_layers`: `4`
  - `d_ff`: `512`
- Trainable params: `3,388,930`
- Checkpoint: `assignment_llm_1/saved_model/transformer_imdb_large.pt`
- Validation metrics:
  - Accuracy: `0.7374`
  - Precision: `0.7392`
  - Recall: `0.7336`
  - F1: `0.7364`
assignment_llm_1/assignment_text/documentation/documentation.md
ADDED
@@ -0,0 +1,142 @@
| 1 |
+
## Homework 1 (Part I) – Transformer-Based Sentiment Analysis on IMDB
|
| 2 |
+
|
| 3 |
+
### 1. Introduction
|
| 4 |
+
|
| 5 |
+
This project implements a Transformer-based neural network for binary sentiment analysis on movie reviews. Given a raw text review from the IMDB dataset, the model predicts whether the review expresses a positive or negative sentiment. Sentiment analysis is a fundamental task in natural language processing (NLP) with applications in opinion mining, recommendation systems, and social media monitoring, where understanding users’ attitudes at scale is essential.
|
| 6 |
+
|
| 7 |
+
The Transformer architecture is well suited to this task because it relies on self-attention rather than recurrent connections. Self-attention allows the model to capture long-range dependencies between words regardless of their distance in the sequence and to focus on sentiment-bearing tokens (for example, “not”, “excellent”, “boring”) even when they appear far apart. This makes Transformers more effective and easier to train in parallel than traditional RNN or LSTM models for document-level classification.
|
| 8 |
+
|
| 9 |
+
### 2. Dataset Description
|
| 10 |
+
|
| 11 |
+
The model is trained and validated on the IMDB movie reviews dataset, loaded via the HuggingFace `datasets` library using the `"imdb"` configuration. This dataset corresponds to the Large Movie Review Dataset introduced by Maas et al. (2011), which contains 50,000 labeled movie reviews (25,000 train and 25,000 test) plus additional unlabeled reviews. The original dataset is available from the Stanford AI Lab at `http://ai.stanford.edu/~amaas/data/sentiment/`. Each labeled review is annotated with a binary label: 0 for negative and 1 for positive sentiment.
|
| 12 |
+
|
| 13 |
+
In this implementation, I use the labeled training split exposed by the HuggingFace wrapper and perform an additional 80/20 train/validation split with stratification over the labels. This results in approximately 20,000 reviews for training and 5,000 reviews for validation, preserving the original class balance. The official test split and the unlabeled reviews from the Large Movie Review Dataset are not used in this assignment. The labels are integers taking values in {0, 1}, corresponding to negative and positive sentiment respectively.
|
| 14 |
+
|
| 15 |
+
The raw text consists of HTML-formatted movie reviews. As a minimal preprocessing step, HTML line breaks (`<br />`) are removed and all non-alphanumeric characters (except whitespace) are filtered out. No additional filtering or subsetting (such as length-based pruning) is performed; instead, sequence length is controlled later during token-to-index mapping and padding.
|
| 16 |
+
|
| 17 |
+
When using this dataset, the canonical citation requested by the authors is:
|
| 18 |
+
|
| 19 |
+
Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. 2011. Learning Word Vectors for Sentiment Analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pages 142–150.
|
| 20 |
+
|
| 21 |
+
### 3. Data Preprocessing
|
| 22 |
+
|
| 23 |
+
Preprocessing is implemented in three main stages: tokenization, vocabulary construction, and sequence padding.
|
| 24 |
+
|
| 25 |
+
Tokenization uses a simple regular expression–based tokenizer. The text is converted to lowercase, HTML line breaks are replaced by spaces, and all characters that are not letters, digits, or whitespace are removed. The cleaned string is then split on whitespace to obtain a list of tokens. This approach yields a plain word-level tokenization that is easy to interpret and sufficient for this assignment.
Vocabulary construction builds a word-to-index mapping from the training texts only, which avoids information leakage from the validation set. A `Counter` is used to accumulate token frequencies over all training reviews. Two special tokens are reserved: `<PAD>` with index 0 and `<UNK>` with index 1. The remaining slots (up to `MAX_VOCAB = 5000`) are filled with the most frequent words in the training corpus. Any word outside this vocabulary is mapped to the `<UNK>` index at inference time.
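
A minimal sketch of this frequency-based construction (the signature is an assumption, not necessarily that of `c1.py`):

```python
from collections import Counter

def build_vocab(token_lists, max_vocab=5000):
    """Build a word-to-index map from training token lists only."""
    counts = Counter()
    for tokens in token_lists:
        counts.update(tokens)                    # accumulate token frequencies
    vocab = {"<PAD>": 0, "<UNK>": 1}             # reserved special tokens
    for word, _ in counts.most_common(max_vocab - len(vocab)):
        vocab[word] = len(vocab)                 # next free index
    return vocab

vocab = build_vocab([["good", "good", "bad"], ["good"]], max_vocab=4)
```

Because only training texts are passed in, validation-only words naturally fall back to `<UNK>`.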
For sequence padding and truncation, each review is converted to a fixed-length sequence of token IDs with `MAX_LEN = 64`. Tokens are mapped to their integer indices via the vocabulary; if a token is not present, it is mapped to `<UNK>`. If a sequence is shorter than 64 tokens, it is padded on the right with `<PAD>` tokens until length 64. If it is longer, it is truncated to the first 64 tokens. This produces a dense matrix of shape (number_of_examples, 64), which is used as input to the model.
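
The mapping-plus-padding logic amounts to a few lines (a sketch; the helper name is hypothetical):

```python
def encode(tokens, vocab, max_len=64, pad_id=0, unk_id=1):
    """Map tokens to IDs, truncate to max_len, then right-pad with <PAD>."""
    ids = [vocab.get(tok, unk_id) for tok in tokens]   # OOV -> <UNK>
    ids = ids[:max_len]                                # truncate long reviews
    return ids + [pad_id] * (max_len - len(ids))       # right-pad short ones

toy_vocab = {"<PAD>": 0, "<UNK>": 1, "good": 2, "movie": 3}
short_seq = encode(["good", "movie"], toy_vocab, max_len=4)
long_seq = encode(["good"] * 100, toy_vocab, max_len=4)
```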
### 4. Model Architecture
At a high level, the model is a Transformer encoder stacked on top of learned token embeddings and sinusoidal positional encodings, followed by a global pooling operation and a linear classification head. It is a pure encoder architecture tailored for sequence-level classification rather than sequence-to-sequence generation.
First, input token IDs are passed through an embedding layer that maps each vocabulary index to a continuous vector of dimension `d_model = 128`. These token embeddings capture distributional semantics learned during training. Since embeddings alone do not encode the order of tokens, the model adds positional encodings. A standard sinusoidal positional encoding is precomputed for positions up to a maximum length and added elementwise to the token embeddings. This provides the model with information about the absolute position of each token in the sequence.
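
The standard sinusoidal encoding can be precomputed as below (a NumPy sketch of the usual formulation, PE[pos, 2i] = sin(pos / 10000^(2i/d_model)) and PE[pos, 2i+1] = cos(...); the function name is illustrative):

```python
import numpy as np

def sinusoidal_pe(max_len: int, d_model: int) -> np.ndarray:
    """Precompute sinusoidal positional encodings of shape (max_len, d_model)."""
    pos = np.arange(max_len)[:, None]              # (max_len, 1) positions
    two_i = np.arange(0, d_model, 2)[None, :]      # even dimension indices 2i
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                   # even dims: sine
    pe[:, 1::2] = np.cos(angles)                   # odd dims: cosine
    return pe                                      # added elementwise to embeddings

pe = sinusoidal_pe(64, 128)
```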
The resulting sequence of encoded token representations is processed by a stack of Transformer encoder layers. In this implementation, there are `NUM_LAYERS = 2` such layers, each consisting of a multi-head self-attention sublayer followed by a position-wise feed-forward network, both wrapped in residual connections and layer normalization.
Within each encoder block, multi-head self-attention projects the input sequence into queries, keys, and values of dimension `d_k = d_model / num_heads`, where `num_heads = 8`. For each head, attention weights are computed as scaled dot products between queries and keys, divided by the square root of `d_k` to ensure stable gradients. A softmax is then applied to obtain a probability distribution over positions, and a weighted sum of the value vectors is taken. The outputs of all heads are concatenated and passed through a final linear projection back to `d_model`. Self-attention lets each token representation attend to all other tokens in the sequence, enabling the model to capture context such as negation and long-distance sentiment cues.
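
For a single head, the computation reduces to a few matrix operations (a NumPy sketch; in the multi-head layer this runs once per head before concatenation):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """One attention head: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (seq, seq) similarities
    scores -= scores.max(axis=-1, keepdims=True)    # shift for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over key positions
    return weights @ V, weights                     # weighted sum of value vectors

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 16)) for _ in range(3))
out, attn = scaled_dot_product_attention(Q, K, V)
```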
The output of the attention sublayer is added back to the original input via a residual connection and normalized using layer normalization. A feed-forward sublayer then applies a two-layer MLP with hidden dimension `d_ff = 256` and ReLU activation, followed again by residual addition and layer normalization. This combination allows the model to perform both contextual mixing (via attention) and local non-linear transformations (via the feed-forward network).
After the final encoder layer, the model applies global average pooling over the sequence dimension, computing the mean of the hidden states across all time steps. This yields a single `d_model`-dimensional vector that summarizes the entire review. A dropout layer with dropout probability 0.1 is applied during training to reduce overfitting. Finally, a linear classification head maps this pooled representation to `num_classes = 2` logits, corresponding to negative and positive sentiment. The predicted class is obtained via an argmax over these logits.
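
A PyTorch sketch of this pooling-plus-head stage (the class and variable names are illustrative assumptions, not those in `c1.py`):

```python
import torch
import torch.nn as nn

class PooledClassifierHead(nn.Module):
    """Global average pooling over time, dropout, then a linear head."""
    def __init__(self, d_model=128, num_classes=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, hidden):                 # hidden: (batch, seq_len, d_model)
        pooled = hidden.mean(dim=1)            # average over the sequence dimension
        return self.fc(self.dropout(pooled))   # (batch, num_classes) logits

head = PooledClassifierHead().eval()           # eval() disables dropout
with torch.no_grad():
    logits = head(torch.randn(8, 64, 128))
    preds = logits.argmax(dim=-1)              # predicted class per example
```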
In summary, self-attention in this model works by dynamically weighting each token’s contribution to every other token’s representation based on learned similarity scores. Tokens with stronger semantic or syntactic relevance to a given position receive higher attention weights, enabling the model to focus on the most informative parts of the review for sentiment prediction.
### 5. Training Pipeline
The training loop is implemented around a standard supervised learning pipeline for PyTorch models. The loss function used is cross-entropy loss (`nn.CrossEntropyLoss`), which is appropriate for multi-class (here, binary) classification with mutually exclusive classes. It directly optimizes the log-likelihood of the correct label given the model’s logits.
Optimization is performed using the Adam optimizer (`torch.optim.Adam`) with an initial learning rate of `LR = 0.001`. Adam is chosen for its robustness to gradient scaling and its ability to adaptively tune learning rates per parameter, which is beneficial when training Transformer-style models. To encourage better convergence, a learning rate scheduler (`StepLR`) is applied: the learning rate is multiplied by `gamma = 0.5` every `step_size = 2` epochs. This gradually reduces the learning rate as training progresses, helping the model to fine-tune around a local optimum.
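
The optimizer and scheduler setup can be sketched as follows (using a stand-in `nn.Linear` so the snippet is self-contained; the real script attaches them to the Transformer classifier):

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 2)   # stand-in for the Transformer classifier
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

lrs = []
for epoch in range(5):
    # ... one full pass over the training DataLoader would go here ...
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()        # halves the learning rate every 2 epochs
```

With `step_size=2` and `gamma=0.5`, the learning rate sequence over five epochs is 0.001, 0.001, 0.0005, 0.0005, 0.00025.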
Batches are constructed using `DataLoader` objects with a batch size of `BATCH_SIZE = 32`. The training loader shuffles the data at each epoch, while the validation loader does not. The model is trained for `EPOCHS = 5` full passes over the training set. Dropout with probability 0.1 is applied to the embeddings and intermediate representations, providing regularization by randomly masking components during training. No explicit early stopping is implemented in code; instead, performance is monitored on the validation set after each epoch via evaluation metrics. If desired, early stopping could be implemented externally by tracking the best validation F1 score and halting when it stops improving.
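
The loader setup is standard PyTorch; a minimal sketch with toy stand-ins for the preprocessed ID matrix and label vector:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the (num_examples, MAX_LEN) ID matrix and label vector.
X = torch.randint(0, 5000, (100, 64))
y = torch.randint(0, 2, (100,))

train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=False)

batch_x, batch_y = next(iter(train_loader))   # one mini-batch of IDs and labels
```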
### 6. Evaluation Metrics
The evaluation function computes four metrics on the validation (or test) data: accuracy, precision, recall, and F1-score. Predictions are obtained by taking the argmax over the model’s logits for each example.
Accuracy measures the overall proportion of correctly classified reviews and is a natural baseline metric for binary sentiment classification. However, accuracy alone can be misleading if the dataset is imbalanced or if different error types have different costs. Therefore, precision, recall, and F1-score (with binary averaging) are also reported. Precision captures the fraction of predicted positive reviews that are truly positive, recall captures the fraction of true positive reviews that are correctly identified, and the F1-score is the harmonic mean of precision and recall. For sentiment analysis where both false positives and false negatives are important, F1-score provides a more informative single-number summary than accuracy alone.
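
These metrics come directly from scikit-learn; a tiny hand-checked example (2 true positives, 1 false positive, 1 false negative):

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 1, 0, 0, 1]

acc = accuracy_score(y_true, y_pred)           # 4 of 6 correct
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary"           # label 1 (positive) is the target class
)
# Here precision = recall = 2/3, so their harmonic mean f1 is also 2/3.
```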
### 7. Experimental Results (Updated with Model-Size Study)
The training script now explicitly explores model complexity by training three Transformer sizes:
- `small`: `d_model=64`, `num_heads=4`, `num_layers=1`, `d_ff=128`
- `medium`: `d_model=128`, `num_heads=8`, `num_layers=2`, `d_ff=256`
- `large`: `d_model=256`, `num_heads=8`, `num_layers=4`, `d_ff=512`
For each size, the code reports training loss and validation metrics (accuracy, precision, recall, F1-score). It also computes trainable parameter counts and prints a final comparison table so performance can be compared against model complexity in a single run.
In addition to console logging, the script saves a human-readable report:
- `assignment_text/saved_model/transformer_imdb_experiment_report.md`
This report includes run metadata, per-size architecture details, checkpoint paths, parameter counts, and metric summaries, making the experiment easy to share and inspect.
### 8. Model Analysis and Error Inspection
The main strength of this model is its ability to exploit contextual information across the entire review via self-attention. Unlike bag-of-words or simple n-gram models, it can represent interactions between distant tokens, such as handling negation (“not good”) or nuanced phrases that influence sentiment only in context. Additionally, the use of positional encoding preserves word order, which is crucial for distinguishing, for example, “not only good but great” from superficially similar but differently ordered phrases.
However, several weaknesses are apparent. First, the tokenizer is very simple and does not handle subword units, morphology, or out-of-vocabulary words in a sophisticated way. Rare words are collapsed into the `<UNK>` token, which can obscure important sentiment cues. Second, reviews are truncated to 64 tokens; for very long reviews, important information appearing later in the text may be discarded, potentially harming performance. Third, the model capacity (two encoder layers with `d_model = 128`) is modest compared to modern large-scale Transformers and may limit performance on more challenging or nuanced reviews.
From an overfitting perspective, validation metrics track training performance reasonably well, and the use of dropout and a relatively small model size helps keep overfitting under control. That said, with only 5 epochs of training, there is also a risk of underfitting: the model might not fully exploit the available training data. Adding more epochs with careful learning rate scheduling or early stopping based on validation F1 could further improve performance.
Potential improvements include increasing the maximum sequence length to capture more of each review, adopting a more advanced tokenizer such as Byte-Pair Encoding (BPE) or WordPiece to better handle rare words and subwords, and scaling up the model (more layers, larger `d_model`, larger `d_ff`) if computational resources allow. Pretraining on a large unlabeled corpus and then fine-tuning on IMDB would also likely yield significant gains, following the standard transfer learning paradigm for NLP.
For concrete examples, error-analysis instances are stored in:
- `assignment_text/documentation/error_analysis.json`
Based on the current file (10 sampled misclassified reviews), the dominant failure mode is clear: every listed mistake is a **false positive**, where the true label is negative (`0`) but the model predicts positive (`1`). This indicates the model is comparatively weaker at recognizing subtle or mixed negative sentiment.
Common patterns observed in these errors include:
- **Mixed sentiment with mild praise + stronger criticism**: e.g., "worth a rental" or "semi-alright action" appears early, but the overall judgment is negative.
- **Sarcasm and irony**: phrases such as "only thing good..." or "top of the garbage list" are difficult for the model to interpret reliably.
- **Contrastive structure**: reviews that begin with a positive setup and then pivot to criticism are often classified as positive.
- **Soft negative wording**: "average", "passable", and "not really worth watching again" can be semantically negative but lexically less explicit.
- **Calibration issues on borderline negatives**: several wrong predictions carry only moderate confidence for the positive class, suggesting an uncertain decision boundary around nuanced negative reviews.
This error profile suggests improvements should prioritize better handling of nuanced negative language (especially sarcasm and contrastive discourse), for example via richer tokenization (subword methods), longer context windows, and additional hard-negative examples during training.
### 9. Conclusion
This implementation demonstrates that a relatively small Transformer encoder can perform effective sentiment analysis on the IMDB dataset using only word-level tokenization and simple preprocessing. Through self-attention and positional encoding, the model captures long-range dependencies and focuses on sentiment-relevant parts of each review, achieving strong validation performance compared with simple baselines.
From this assignment, the key takeaways are that Transformer-based models are flexible and powerful for text classification tasks and that even a compact architecture benefits from core Transformer ideas such as multi-head self-attention, residual connections, and normalization. At the same time, overall performance is sensitive to choices in tokenization, sequence length, and model capacity, highlighting the importance of careful design and experimentation when applying Transformers to NLP problems.
### 10. Reproducibility, Code Organization, and Key Implementation Details
The project is organized around core scripts and a directory for saved artifacts:
- `assignment_text/code/c1.py`: End-to-end training/evaluation script with multi-size experiments (`small`, `medium`, `large`). It saves per-size checkpoints and a markdown experiment report.
- `assignment_text/code/c1_analysis.py`: Evaluation and qualitative analysis script for loading a trained checkpoint and inspecting misclassifications. It supports `--model_size` to select `small`, `medium`, or `large`.
- `assignment_text/saved_model/transformer_imdb_small.pt`, `assignment_text/saved_model/transformer_imdb_medium.pt`, `assignment_text/saved_model/transformer_imdb_large.pt`: Size-specific checkpoints.
- `assignment_text/saved_model/transformer_imdb_experiment_report.md`: Experiment summary report generated after training.
- `assignment_text/saved_model/transformer_imdb.pt`: Compatibility summary file containing best-model metadata and all results.
Within `c1.py`, the data preprocessing pipeline is implemented by `tokenize`, `build_vocab`, and `preprocess_data`. The `tokenize` function normalizes and splits raw review text into lowercase word tokens, `build_vocab` constructs a frequency-based mapping from tokens to integer indices using only the training texts (reserving indices 0 and 1 for padding and unknown tokens), and `preprocess_data` converts each review into a fixed-length sequence of token IDs via truncation and padding. The `IMDBDataset` class wraps these sequences and labels into a PyTorch Dataset, and `DataLoader` objects provide mini-batches for training and validation.
The model components are built in a modular fashion. `PositionalEncoding` implements sinusoidal position encodings that are added to token embeddings, `MultiHeadAttention` performs scaled dot-product self-attention over the sequence across multiple heads, and `TransformerEncoderBlock` combines multi-head attention, a position-wise feed-forward network, residual connections, and layer normalization into a single encoder layer. The `TransformerClassifier` class composes an embedding layer, positional encoding, a stack of encoder blocks, global average pooling over the sequence dimension, and a final linear layer that outputs logits for the two sentiment classes.
The training and evaluation logic is encapsulated in `train_model` and `evaluate_model`. `train_model` iterates over the training DataLoader for a specified number of epochs, computes the cross-entropy loss, performs backpropagation and optimizer updates (Adam with an initial learning rate of 0.001), applies a StepLR scheduler that halves the learning rate every two epochs, and reports validation accuracy, precision, recall, and F1-score at the end of each epoch. `evaluate_model` runs the model in evaluation mode on a given DataLoader and aggregates predictions and labels to compute the same metrics. The function `load_imdb_texts` serves as a thin wrapper around the HuggingFace `datasets.load_dataset` API and clearly documents that the underlying data originates from the Large Movie Review Dataset (Maas et al., 2011).
To reproduce the experiments, the main dependencies are PyTorch, scikit-learn, NumPy, and the HuggingFace `datasets` library. The script is designed to use a CUDA GPU if available, otherwise it falls back to CPU.
Global hyperparameters are defined near the end of `c1.py`:
- `MAX_VOCAB = 5000` (maximum vocabulary size)
- `MAX_LEN = 64` (maximum sequence length in tokens)
- `BATCH_SIZE = 32` (mini-batch size)
- `EPOCHS = 5` (number of training epochs)
- `LR = 0.001` (initial learning rate for Adam)
Model-size-specific hyperparameters are defined in `MODEL_SIZES`:
- `small`: `d_model=64`, `num_heads=4`, `num_layers=1`, `d_ff=128`
- `medium`: `d_model=128`, `num_heads=8`, `num_layers=2`, `d_ff=256`
- `large`: `d_model=256`, `num_heads=8`, `num_layers=4`, `d_ff=512`
Running `c1.py` end-to-end will (1) download and preprocess the IMDB training split, (2) train all three model sizes on the training portion, (3) evaluate each model on the validation portion, and (4) save checkpoints plus an experiment report to the `assignment_text/saved_model` directory. Running `c1_analysis.py` with `--model_size` then allows targeted evaluation and qualitative error analysis for the selected size. With these components documented, the experiment is fully reproducible and can be extended for further study.
assignment_llm_1/assignment_text/documentation/error_analysis.json
ADDED
@@ -0,0 +1,92 @@
[
{
"example_id": 1,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.13079692423343658,
"confidence_pos": 0.8692030906677246,
"text": "Worth the entertainment value of a rental, especially if you like action movies. This one features the usual car chases, fights with the great Van Damme kick style, shooting battles with the 40 shell load shotgun, and even terrorist style bombs. All of this is entertaining and competently handled but there is nothing that really blows you away if you've seen your share before.<br /><br />The plot is made interesting by the inclusion of a rabbit, which is clever but hardly profound. Many of the characters are heavily stereotyped -- the angry veterans, the terrified illegal aliens, the crooked cops, the indifferent feds, the bitchy tough lady station head, the crooked politician, the fat federale who looks like he was typecast as the Mexican in a Hollywood movie from the 1940s. All passably acted but again nothing special.<br /><br />I thought the main villains were pretty well done and fairly well acted. By the end of the movie you certainly knew who the good guys were and weren't. There was an emotional lift as the really bad ones got their just deserts. Very simplistic, but then you weren't expecting Hamlet, right? The only thing I found really annoying was the constant cuts to VDs daughter during the last fight scene.<br /><br />Not bad. Not good. Passable 4.",
"explanation": "The model likely struggled with this review due to its mixed sentiment, where the reviewer acknowledges some entertainment value while simultaneously critiquing the film's lack of originality and reliance on stereotypes. Phrases like \"worth the entertainment value of a rental\" may have been interpreted positively, overshadowing the underlying negative sentiment expressed through descriptors like \"nothing special\" and \"hardly profound.\" Additionally, the nuanced critique of character portrayals and plot elements may have led to misclassification, as the model may not have effectively captured the sarcasm or the overall negative tone."
},
{
"example_id": 2,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.19945980608463287,
"confidence_pos": 0.8005402088165283,
"text": "its a totally average film with a few semi-alright action sequences that make the plot seem a little better and remind the viewer of the classic van dam films. parts of the plot don't make sense and seem to be added in to use up time. the end plot is that of a very basic type that doesn't leave the viewer guessing and any twists are obvious from the beginning. the end scene with the flask backs don't make sense as they are added in and seem to have little relevance to the history of van dam's character. not really worth watching again, bit disappointed in the end production, even though it is apparent it was shot on a low budget certain shots and sections in the film are of poor directed quality",
"explanation": "The model likely struggled with this review due to the presence of mixed sentiment and subtle linguistic cues that indicate negativity, such as phrases like \"totally average,\" \"not really worth watching again,\" and \"bit disappointed.\" Additionally, the use of qualifiers like \"semi-alright\" and the overall tone may have led the model to misinterpret the review as more positive than intended, as it could have focused on the mention of \"action sequences\" without fully grasping the critical context surrounding them."
},
{
"example_id": 3,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.07963180541992188,
"confidence_pos": 0.9203682541847229,
"text": "First off let me say, If you haven't enjoyed a Van Damme movie since bloodsport, you probably will not like this movie. Most of these movies may not have the best plots or best actors but I enjoy these kinds of movies for what they are. This movie is much better than any of the movies the other action guys (Segal and Dolph) have thought about putting out the past few years. Van Damme is good in the movie, the movie is only worth watching to Van Damme fans. It is not as good as Wake of Death (which i highly recommend to anyone of likes Van Damme) or In hell but, in my opinion it's worth watching. It has the same type of feel to it as Nowhere to Run. Good fun stuff!",
"explanation": "The model likely struggled with this review due to its mixed sentiment and nuanced language. While the reviewer expresses enjoyment of Van Damme's movies, they also imply that the film may not appeal to a broader audience, which can create confusion. Additionally, phrases like \"not as good as\" and \"only worth watching to Van Damme fans\" introduce a negative sentiment that may have been overshadowed by the overall positive tone, leading to the misclassification."
},
{
"example_id": 4,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.04191816598176956,
"confidence_pos": 0.958081841468811,
"text": "Isaac Florentine has made some of the best western Martial Arts action movies ever produced. In particular US Seals 2, Cold Harvest, Special Forces and Undisputed 2 are all action classics. You can tell Isaac has a real passion for the genre and his films are always eventful, creative and sharp affairs, with some of the best fight sequences an action fan could hope for. In particular he has found a muse with Scott Adkins, as talented an actor and action performer as you could hope for. This is borne out with Special Forces and Undisputed 2, but unfortunately The Shepherd just doesn't live up to their abilities.<br /><br />There is no doubt that JCVD looks better here fight-wise than he has done in years, especially in the fight he has (for pretty much no reason) in a prison cell, and in the final showdown with Scott, but look in his eyes. JCVD seems to be dead inside. There's nothing in his eyes at all. It's like he just doesn't care about anything throughout the whole film. And this is the leading man.<br /><br />There are other dodgy aspects to the film, script-wise and visually, but the main problem is that you are utterly unable to empathise with the hero of the film. A genuine shame as I know we all wanted this film to be as special as it genuinely could have been. There are some good bits, mostly the action scenes themselves. This film had a terrific director and action choreographer, and an awesome opponent for JCVD to face down. This could have been the one to bring the veteran action star back up to scratch in the balls-out action movie stakes.<br /><br />Sincerely a shame that this didn't happen.",
"explanation": "The model likely struggled with this review due to its mixed sentiment and nuanced language. While the reviewer praises Isaac Florentine's work and highlights positive aspects of the films, the overall sentiment shifts negatively when discussing JCVD's performance, particularly with phrases like \"dead inside\" and \"doesn't care.\" This complexity, combined with the presence of both positive and negative keywords, may have led the model to misinterpret the overall sentiment as positive."
},
{
"example_id": 5,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.45393702387809753,
"confidence_pos": 0.5460630059242249,
"text": "A group of heirs to a mysterious old mansion find out that they have to live in it as part of a clause in the will or be disinherited, but they soon find out of its history of everybody whom had lived there before them having either died in weird accidents or having had killed each other.<br /><br />You've seen it all before, and this one is too low-budget and slow paced to be scary, and doesn't have any real surprises in the climax. No special effects or gore to speak of, in fact the only really amusing thing about the whole film is the quality of the English dubbing, which at times is as bad as a cheap martial arts movie.<br /><br />3 out of 10, pretty low in the pecking order of 80's haunted house movies.",
"explanation": "The machine learning model likely struggled with this review due to the presence of mixed sentiment and subtle sarcasm. Phrases like \"you've seen it all before\" and \"the only really amusing thing\" can be interpreted as positive, while the overall context and explicit negative ratings (e.g., \"3 out of 10\") indicate dissatisfaction. Additionally, the model may have misinterpreted the lack of special effects and the critique of dubbing as neutral or positive elements, leading to an incorrect classification."
},
{
"example_id": 6,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.34184929728507996,
"confidence_pos": 0.6581506729125977,
"text": "The Forgotten (AKA: Don't Look In The Basement) is a very cheaply made and very old looking horror movie.<br /><br />The story is very slow and never really reaches anything worth getting excited about.<br /><br />The patients at the asylum are embarrassingly funny especially Sam and the old woman who always quotes an old saying to everyone. (Look out for the bit when she gets close to the camera, tell me you can watch without laughing!).<br /><br />Now the gore is very poor looking, with the blood looking pink in many scenes so it doesn't really deserve its place on the video nasties list!.<br /><br />Overall if you aren't looking for a fantastic horror film and have some time to spare then it's worth a watch.",
"explanation": "The model likely struggled with this review due to its mixed sentiment and the presence of sarcasm. Phrases like \"embarrassingly funny\" and \"worth a watch\" can be interpreted positively, despite the overall negative tone conveyed by descriptors like \"cheaply made\" and \"very slow.\" Additionally, the use of humor in describing the film's shortcomings may have misled the model into classifying the review as positive."
},
{
"example_id": 7,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.1487887054681778,
"confidence_pos": 0.8512113094329834,
"text": "I of course saw the previews for this at the beginning of some other Lion's Gate extravaganza, so of course it was only the best parts and therefore looked intriguing. And it is, to a point. A young college student (Sarah)is finding riddles all over the place and is becoming obsessed with answering them, and in doing so she's unwittingly becoming involved in some game. Now that's fairly intriguing right there but unfortunately it all gets rather muddled and becomes so complicated that the viewer (like myself) will most likely become frustrated. Characters appear with little introduction and you're not really sure who they are or why Sarah knows them or is hanging out with them. All of this has something to do with this woman who tried to drown a young boy years ago and her reason for that was that it's \"all part of the design\". In reality, it's all part of the \"very sketchy script\" and when the film is over you'll find yourself feeling that you've lost about an hour and a half of your life that you want back for more productive uses of your time, like cleaning the bathroom, for instance. 4 out of 10.",
"explanation": "The model likely struggled with this review due to the presence of mixed sentiment and nuanced language. Phrases like \"fairly intriguing\" and \"best parts\" may have been interpreted positively, while the overall context reveals frustration with the plot's complexity and a \"very sketchy script.\" Additionally, the use of sarcasm and the negative conclusion about losing time could have been overlooked by the model, leading to an incorrect positive classification."
},
{
"example_id": 8,
"true_label": 0,
"predicted_label": 1,
"confidence_neg": 0.32372772693634033,
"confidence_pos": 0.6762722134590149,
"text": "Four things intrigued me as to this film - firstly, it stars Carly Pope (of \"Popular\" fame), who is always a pleasure to watch. Secdonly, it features brilliant New Zealand actress Rena Owen. Thirdly, it is filmed in association with the New Zealand Film Commission. Fourthly, a friend recommended it to me. However, I was utterly disappointed. The whole storyline is absurd and complicated, with very little resolution. Pope's acting is fine, but Owen is unfortunately under-used. The other actors and actresses are all okay, but I am unfamiliar with them all. Aside from the nice riddles which are littered throughout the movie (and Pope and Owen), this film isn't very good. So the moral of the story is...don't watch it unless you really want to.",
|
| 72 |
+
"explanation": "The model likely struggled with this review due to the presence of mixed sentiment and subtle sarcasm. While the reviewer mentions positive aspects such as the cast and riddles, the overall tone is negative, particularly in phrases like \"utterly disappointed\" and \"this film isn't very good,\" which may have been overshadowed by the initial positive keywords. Additionally, the complexity of the review's structure may have led the model to misinterpret the sentiment conveyed."
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"example_id": 9,
|
| 76 |
+
"true_label": 0,
|
| 77 |
+
"predicted_label": 1,
|
| 78 |
+
"confidence_neg": 0.35207340121269226,
|
| 79 |
+
"confidence_pos": 0.6479266285896301,
|
| 80 |
+
"text": "<br /><br />Never ever take a film just for its good looking title.<br /><br />Although it all starts well, the film suffers the same imperfections you see in B-films. Its like at a certain moment the writer does not any more how to end the film, so he ends it in a way nobody suspects it thinking this way he is ingenious.<br /><br />A film to be listed on top of the garbage list.<br /><br />",
|
| 81 |
+
"explanation": "The model likely struggled with the review due to the presence of sarcasm and mixed sentiment, particularly in phrases like \"good looking title\" and \"thinking this way he is ingenious,\" which could be misinterpreted as positive. Additionally, the use of phrases like \"top of the garbage list\" may not have been weighted heavily enough by the model, leading to an incorrect positive classification despite the overall negative sentiment expressed in the review."
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"example_id": 10,
|
| 85 |
+
"true_label": 0,
|
| 86 |
+
"predicted_label": 1,
|
| 87 |
+
"confidence_neg": 0.42465338110923767,
|
| 88 |
+
"confidence_pos": 0.5753466486930847,
|
| 89 |
+
"text": "Lowe returns to the nest after, yet another, failed relationship, to find he's been assigned to jury duty. It's in the plans to, somehow, get out of it, when he realizes the defendant is the girl he's had a serious crush on since the first grade.<br /><br />Through living in the past by telling other people about his feelings towards this girl (played by Camp), Lowe remembers those feelings and does everything in his power to clear Camp of attempted murder, while staying away from the real bad guys at the same time, and succeeding in creating a successful film at the same time.<br /><br />I've heard that St Augustine is the oldest city in the US, and I also know it has some ties to Ponce de Leon, so the backdrop is a good place to start. Unfortunately, it's the only thing good about this movie. The local police are inept, the judge is an idiot, and the defense counsel does everything in her power to make herself look like Joanie Cunningham! I don't know whether to blame the director for poor direction, or for just letting the cast put in such a hapless effort.<br /><br />In short, this movie was so boring, I could not even sleep through it! 1 out of 10 stars!",
|
| 90 |
+
"explanation": "The model likely struggled with this review due to the presence of mixed sentiment and sarcasm. While the reviewer mentions some positive elements, such as the interesting backdrop of St. Augustine, the overall tone is negative, highlighted by phrases like \"the only thing good about this movie\" and criticisms of the characters and plot. The model may have misinterpreted the positive keywords and failed to recognize the underlying negative sentiment conveyed through the reviewer's frustration and sarcasm."
|
| 91 |
+
}
|
| 92 |
+
]
|
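The records in `error_analysis.json` store both class confidences alongside the predicted and true labels. A quick consistency check on these fields, using the values copied from example 8 above (the field names are exactly those in the JSON):

```python
# Consistency check for one error_analysis.json record (example 8 above):
# the two softmax confidences should sum to 1, the predicted label should
# be the argmax of the confidences, and the record should indeed be an error.
rec = {
    "true_label": 0,
    "predicted_label": 1,
    "confidence_neg": 0.32372772693634033,
    "confidence_pos": 0.6762722134590149,
}

# softmax outputs sum to 1 (up to floating-point noise)
assert abs(rec["confidence_neg"] + rec["confidence_pos"] - 1.0) < 1e-6
# predicted label follows the larger confidence
assert rec["predicted_label"] == int(rec["confidence_pos"] > rec["confidence_neg"])
# a false positive: the model said positive, the gold label is negative
assert rec["predicted_label"] != rec["true_label"]
```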
assignment_llm_1/assignment_text/saved_model/Untitled
ADDED
@@ -0,0 +1 @@
saved_model

assignment_llm_1/assignment_text/saved_model/transformer_imdb_experiment_report.md
ADDED
@@ -0,0 +1,77 @@
# IMDB Transformer Model-Size Experiment Report

- Generated at: `2026-02-12 01:50:15`
- Device: `cuda`
- Training samples: `20000`
- Validation samples: `5000`
- Max vocab size: `5000`
- Max sequence length: `64`
- Batch size: `32`
- Epochs: `5`
- Learning rate: `0.001`

## Overall Comparison

| Model Size | Trainable Params | Accuracy | Precision | Recall | F1 | Checkpoint |
|---|---:|---:|---:|---:|---:|---|
| small | 353,602 | 0.7742 | 0.7627 | 0.7960 | 0.7790 | `./saved_model/transformer_imdb_small.pt` |
| medium | 905,218 | 0.7804 | 0.7674 | 0.8048 | 0.7856 | `./saved_model/transformer_imdb_medium.pt` |
| large | 3,388,930 | 0.5190 | 0.5098 | 0.9880 | 0.6726 | `./saved_model/transformer_imdb_large.pt` |

## Best Model

- Best size by validation F1: `medium`
- Checkpoint: `./saved_model/transformer_imdb_medium.pt`
- Trainable parameters: `905,218`
- Metrics:
  - Accuracy: `0.7804`
  - Precision: `0.7674`
  - Recall: `0.8048`
  - F1: `0.7856`

## Per-Model Details

### Small model

- Architecture:
  - `d_model`: `64`
  - `num_heads`: `4`
  - `num_layers`: `1`
  - `d_ff`: `128`
- Trainable params: `353,602`
- Checkpoint: `./saved_model/transformer_imdb_small.pt`
- Validation metrics:
  - Accuracy: `0.7742`
  - Precision: `0.7627`
  - Recall: `0.7960`
  - F1: `0.7790`

### Medium model

- Architecture:
  - `d_model`: `128`
  - `num_heads`: `8`
  - `num_layers`: `2`
  - `d_ff`: `256`
- Trainable params: `905,218`
- Checkpoint: `./saved_model/transformer_imdb_medium.pt`
- Validation metrics:
  - Accuracy: `0.7804`
  - Precision: `0.7674`
  - Recall: `0.8048`
  - F1: `0.7856`

### Large model

- Architecture:
  - `d_model`: `256`
  - `num_heads`: `8`
  - `num_layers`: `4`
  - `d_ff`: `512`
- Trainable params: `3,388,930`
- Checkpoint: `./saved_model/transformer_imdb_large.pt`
- Validation metrics:
  - Accuracy: `0.5190`
  - Precision: `0.5098`
  - Recall: `0.9880`
  - F1: `0.6726`
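As a sanity check on the report above, the F1 column can be reproduced from the precision and recall columns via F1 = 2PR / (P + R); the values agree to within rounding of the four-decimal table entries. The large model's near-perfect recall (0.9880) at chance-level accuracy (0.5190) suggests it collapsed to predicting the positive class almost always.

```python
# Recompute the report's F1 column from its precision/recall columns.
# Values are read off the Overall Comparison table; agreement is checked
# to within rounding of the four-decimal entries.

def f1(precision: float, recall: float) -> float:
    return 2 * precision * recall / (precision + recall)

reported = {
    "small":  (0.7627, 0.7960, 0.7790),
    "medium": (0.7674, 0.8048, 0.7856),
    "large":  (0.5098, 0.9880, 0.6726),
}

for name, (p, r, f) in reported.items():
    assert abs(f1(p, r) - f) < 5e-4, name
```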
assignment_llm_1/data/cifar-10-batches-py/batches.meta
ADDED
Binary file (158 Bytes).
assignment_llm_1/data/cifar-10-batches-py/readme.html
ADDED
@@ -0,0 +1 @@
<meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">
code/RL_model/inference_data/RL_model_inference_v1.jsonl
ADDED
The diff for this file is too large to render.
code/RL_model/inference_data/inference_20260213_002423.jsonl
ADDED
@@ -0,0 +1,2 @@
{"row_index": 0, "data_source": "multiclinsum", "ability": "summarization", "prediction": "{\n  \"low_health_literacy\": \"A 35-year-old woman had a lump on her cheek side (parotid gland) that grew over 2 years. At first, it looked like a slow-growing tumor, but later it got bigger and caused her face to droop. Doctors saw a hard, fixed lump that spread into nearby muscles. An MRI showed the lump was growing deep inside and had spread beyond the gland. She had surgery to remove the whole lump and nearby lymph nodes, keeping her facial nerve safe. Path tests confirmed it was a rare cancer called PNET, which often starts in nerves and has a special genetic change in 60% of the cells. After surgery, she got chemotherapy (four rounds of drugs: vincristine, doxorubicin, cyclophosphamide, ifosfamide, and etoposide) and then radiation to the area. Her body handled the treatment well, with only mild side effects. She checks in every 3 months with scans, and so far, no signs of cancer coming back.\"\n}"}
{"row_index": 1, "data_source": "multiclinsum", "ability": "summarization", "prediction": "{\n  \"intermediate_health_literacy\": \"A 35-year-old woman developed a lump in her left parotid gland two years earlier. Initially, a small biopsy showed a non-cancerous mixed salivary gland tumor, and she had surgery that removed the mass completely. Six months later, a new, firm, painful lump appeared near the surgical site, and she began to lose movement on her left side—this is known as facial paralysis. An MRI scan showed a growing mass measuring 42 mm that had spread into nearby muscles and was not clearly defined. After further testing, doctors found it was a rare cancer called extra-skeletal Ewing’s sarcoma, a type of primary peripheral neuroectodermal tumor (PNET). This tumor has a specific genetic change (EWING/PNET-type translocation) found in 60% of the cells. She had surgery to remove the entire left parotid gland and nearby lymph nodes, with care taken to preserve her facial nerve. Following surgery, she received chemotherapy using a combination of drugs—vincristine, doxorubicin, cyclophosphamide, ifosfamide, and etoposide—for four cycles, followed by external radiation therapy. The treatment was well tolerated, with only mild side effects like mouth sores and skin irritation. At 10 months after treatment, there was no sign of the cancer coming back locally or elsewhere in her body.\"\n}"}
code/RL_model/inference_data/vllm_inference_20260213_003845.jsonl
ADDED
The diff for this file is too large to render.
code/RL_model/inference_data/vllm_inference_20260213_003845_meta.json
ADDED
@@ -0,0 +1,10 @@
{
  "model_path_for_tokenizer": "/home/mshahidul/readctrl/code/RL_model/converted_model/v1",
  "dataset_path": "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/test.parquet",
  "base_url": "http://127.0.0.1:8001/v1",
  "served_model_name": "inference",
  "batch_size": 8,
  "num_samples": 510,
  "output_jsonl": "/home/mshahidul/readctrl/code/RL_model/inference_data/vllm_inference_20260213_003845.jsonl",
  "output_parquet": "/home/mshahidul/readctrl/code/RL_model/inference_data/vllm_inference_20260213_003845.parquet"
}
code/RL_model/inference_data/vllm_inference_20260213_165923.jsonl
ADDED
The diff for this file is too large to render.
code/RL_model/inference_data/vllm_inference_20260213_170937_meta.json
ADDED
@@ -0,0 +1,10 @@
{
  "model_path_for_tokenizer": "/home/mshahidul/readctrl/code/RL_model/models/converted_model/v1",
  "dataset_path": "/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json",
  "base_url": "http://127.0.0.1:8001/v1",
  "served_model_name": "inference",
  "batch_size": 8,
  "num_samples": 200,
  "output_jsonl": "/home/mshahidul/readctrl/code/RL_model/inference_data/vllm_inference_20260213_170937.jsonl",
  "output_parquet": "/home/mshahidul/readctrl/code/RL_model/inference_data/vllm_inference_20260213_170937.parquet"
}
code/RL_model/inference_data/vllm_inference_qwen-qwen3-4b-instruct-2507_20260213_173334_meta.json
ADDED
@@ -0,0 +1,10 @@
{
  "model_path_for_tokenizer": "Qwen/Qwen3-4B-Instruct-2507",
  "dataset_path": "/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json",
  "base_url": "http://127.0.0.1:8001/v1",
  "served_model_name": "inference",
  "batch_size": 8,
  "num_samples": 200,
  "output_jsonl": "/home/mshahidul/readctrl/code/RL_model/inference_data/vllm_inference_qwen-qwen3-4b-instruct-2507_20260213_173334.jsonl",
  "output_parquet": "/home/mshahidul/readctrl/code/RL_model/inference_data/vllm_inference_qwen-qwen3-4b-instruct-2507_20260213_173334.parquet"
}
code/RL_model/unsloth_rl/RL_code.py
ADDED
@@ -0,0 +1,165 @@
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import json
import re

import torch
from datasets import Dataset
from unsloth import FastLanguageModel

from claim_verifier import MedicalClaimVerifier
from health_classifier import classifier

max_seq_length = 8192

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model",
    max_seq_length = max_seq_length,
    load_in_4bit = False,  # Set to False if you have enough VRAM
    fast_inference = False,
)

# Enable gradient checkpointing and prepare the model for training
model = FastLanguageModel.for_training(model)

with open("/home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json", "r") as f:
    data = json.load(f)
dataset = Dataset.from_list(data)

with open('/home/mshahidul/readctrl/code/RL_model/prompt', 'r') as f:
    prompt_template = f.read()

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": prompt_template},
        {"role": "user", "content": f'''
- Input Language: English
- Gold Summary (the anchor reference summary): {x['summary']}
- Source Text (detailed content): {x['fulltext']}
'''},
    ],
    "answer": {
        "fulltext_subclaims": x['fulltext_subclaims'],
        "summary_subclaims": x['summary_subclaims'],
    },
})

verifier = MedicalClaimVerifier()


def claim_reward_func(prompts, completions, answer, **kwargs):
    """
    GRPO reward function.
    Expects 'summary_subclaims' and 'fulltext_subclaims' to be in the dataset.
    """
    rewards = []
    # Loop through the group of completions
    for i in range(len(completions)):
        reward = verifier.get_reward_score(
            completions[i],
            answer[i]["summary_subclaims"],
            answer[i]["fulltext_subclaims"],
        )
        rewards.append(reward)
    return rewards


# def format_reward_func(completions, **kwargs):
#     required_keys = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]
#     scores = []
#     for completion in completions:
#         try:
#             match = re.search(r"<SOLUTION>(.*?)</SOLUTION>", completion, re.DOTALL)
#             content = match.group(1) if match else completion
#             data = json.loads(content)
#             if all(k in data for k in required_keys):
#                 scores.append(2.0)
#             else:
#                 scores.append(-1.0)
#         except Exception:
#             scores.append(-2.0)
#     return scores


def literacy_classifier_reward_func(completions, **kwargs):
    scores = []
    for completion in completions:
        try:
            # 1. Clean up potential Markdown formatting
            cleaned_content = completion[0]['content'].strip()
            if cleaned_content.startswith("```"):
                # Remove the leading ```json or ``` fence and the trailing ```
                cleaned_content = cleaned_content.split("```")[1]
                if cleaned_content.startswith("json"):
                    cleaned_content = cleaned_content[4:]

            # 2. Parse the JSON
            data = json.loads(cleaned_content.strip())

            alignment_score = 0.0
            target_labels = ["low", "intermediate", "proficient"]

            for label in target_labels:
                key = f"{label}_health_literacy"
                text_to_test = data.get(key, "")

                if text_to_test:
                    # Run the DSPy classifier
                    result = classifier(summary_text=text_to_test)
                    predicted = result.label  # Expected format: "low_health_literacy"

                    if predicted == key:
                        alignment_score += 1.0
                    else:
                        # Soft penalty for misclassification
                        alignment_score -= 0.5
                else:
                    # Penalty if a specific literacy level is missing from the JSON
                    alignment_score -= 0.3

            scores.append(alignment_score)

        except Exception:
            # Significant penalty for malformed JSON or failed processing
            scores.append(-1.0)

    return scores


from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate = 5e-6,
    lr_scheduler_type = "cosine",
    weight_decay = 0.1,
    max_prompt_length = 8192,
    max_completion_length = 4096,
    num_generations = 4,  # GRPO group size
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    max_steps = 500,
    bf16 = True,
    output_dir = "medical_grpo_outputs",
)

trainer = GRPOTrainer(
    model = model,
    reward_funcs = [
        claim_reward_func,
        # format_reward_func,
        literacy_classifier_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,  # Use the same dataset from the SFT prep
    tokenizer = tokenizer,
)

trainer.train()

model.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")
tokenizer.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")
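For reference, both reward functions in `RL_code.py` follow TRL's GRPO reward contract: each receives parallel lists (one entry per sampled completion, plus dataset columns such as `answer` as keyword arguments) and returns one float per completion. A minimal self-contained mock of that shape (the chat-style payloads and the containment-based scoring rule here are illustrative assumptions, not the project's `MedicalClaimVerifier`):

```python
# Minimal sketch of the GRPO reward-function contract: inputs are parallel
# lists (one entry per sampled completion); the return value is one float
# per completion. The scoring rule below is a toy stand-in that rewards
# completions mentioning every gold subclaim.

def mock_reward_func(prompts, completions, answer, **kwargs):
    rewards = []
    for i in range(len(completions)):
        text = completions[i][0]["content"]  # chat-style completion
        gold = answer[i]["summary_subclaims"]
        hits = sum(1 for claim in gold if claim in text)
        rewards.append(hits / max(len(gold), 1))
    return rewards

completions = [[{"role": "assistant", "content": "fever and cough"}],
               [{"role": "assistant", "content": "fever only"}]]
answer = [{"summary_subclaims": ["fever", "cough"]},
          {"summary_subclaims": ["fever", "cough"]}]
print(mock_reward_func([None, None], completions, answer))  # [1.0, 0.5]
```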
code/RL_model/unsloth_rl/RL_training.ipynb
ADDED
@@ -0,0 +1,475 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a790cb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth import FastLanguageModel\n",
    "import torch\n",
    "max_seq_length = 2048 # Can increase for longer reasoning traces\n",
    "lora_rank = 32 # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name = \"unsloth/Qwen3-4B-Base\",\n",
    "    max_seq_length = max_seq_length,\n",
    "    load_in_4bit = False, # False for LoRA 16bit\n",
    "    fast_inference = True, # Enable vLLM fast inference\n",
    "    max_lora_rank = lora_rank,\n",
    "    gpu_memory_utilization = 0.9, # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules = [\n",
    "        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
    "        \"gate_proj\", \"up_proj\", \"down_proj\",\n",
    "    ],\n",
    "    lora_alpha = lora_rank*2, # *2 speeds up training\n",
    "    use_gradient_checkpointing = \"unsloth\", # Reduces memory usage\n",
    "    random_state = 3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba056efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# /home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json\n",
    "with open('/home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json', 'r') as f:\n",
    "    synthetic_data_with_gs_summary_en = json.load(f)\n",
    "from datasets import Dataset\n",
    "dataset = Dataset.from_list(synthetic_data_with_gs_summary_en)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa285d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad059247",
   "metadata": {},
   "outputs": [],
   "source": [
    "# /home/mshahidul/readctrl/code/RL_model/prompt\n",
    "with open('/home/mshahidul/readctrl/code/RL_model/prompt', 'r') as f:\n",
    "    prompt_template = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74cbfda",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(lambda x: {\n",
    "    \"prompt\" : [\n",
    "        {\"role\": \"system\", \"content\": prompt_template},\n",
    "        {\"role\": \"user\", \"content\": f'''\n",
    "- Input Language: English\n",
    "- Gold Summary (the anchor reference summary): {x['summary']}\n",
    "- Source Text (detailed content): {x['fulltext']}\n",
    "'''},\n",
    "    ],\n",
    "    \"answer\": {\n",
    "        \"fulltext_subclaims\": x['fulltext_subclaims'],\n",
    "        \"summary_subclaims\": x['summary_subclaims'],\n",
    "    },\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dd615f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# /home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_en_20_67.json\n",
    "import json\n",
    "with open('/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_en_0_80_full.json', 'r') as f:\n",
    "    synthetic_data_diff_labels_en = json.load(f)\n",
    "full_data = []\n",
    "# print((synthetic_data_diff_labels_en)[0].keys())\n",
    "for item in synthetic_data_diff_labels_en:\n",
    "    texts = item['diff_label_texts']\n",
    "    for label in texts:\n",
    "        full_data.append({\n",
    "            \"index\": item['index'],\n",
    "            'label': label,\n",
    "            \"original_text\": item['fulltext'],\n",
    "            \"generated_summary\": texts[label]\n",
    "        })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ba2a6cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/home/mshahidul/readctrl/data/data_annotator_data/syn_data_diff_labels_en_0_80.json', 'w') as f:\n",
    "    json.dump(full_data, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cddc461",
   "metadata": {},
   "outputs": [],
   "source": [
    "# /home/mshahidul/readctrl/data/translated_data/translation_english2bangla_v1.json\n",
    "import json\n",
    "with open('/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_en.json', 'r', encoding='utf-8') as f:\n",
    "    dataset = json.load(f)\n",
    "print(dataset[0].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2b3f2a96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0_low_health_literacy\n",
      "0_intermediate_health_literacy\n",
      "0_proficient_health_literacy\n",
      "1_low_health_literacy\n",
      "1_intermediate_health_literacy\n",
      "1_proficient_health_literacy\n",
      "2_low_health_literacy\n",
      "2_intermediate_health_literacy\n",
      "2_proficient_health_literacy\n",
      "3_low_health_literacy\n",
      "3_intermediate_health_literacy\n",
      "3_proficient_health_literacy\n",
      "4_low_health_literacy\n",
      "4_intermediate_health_literacy\n",
      "4_proficient_health_literacy\n",
      "5_low_health_literacy\n",
      "5_intermediate_health_literacy\n",
      "5_proficient_health_literacy\n",
      "6_low_health_literacy\n",
      "6_intermediate_health_literacy\n",
      "6_proficient_health_literacy\n",
      "7_low_health_literacy\n",
      "7_intermediate_health_literacy\n",
      "7_proficient_health_literacy\n",
      "8_low_health_literacy\n",
      "8_intermediate_health_literacy\n",
      "8_proficient_health_literacy\n",
      "9_low_health_literacy\n",
      "9_intermediate_health_literacy\n",
      "9_proficient_health_literacy\n",
      "10_low_health_literacy\n",
      "10_intermediate_health_literacy\n",
      "10_proficient_health_literacy\n",
      "11_low_health_literacy\n",
      "11_intermediate_health_literacy\n",
      "11_proficient_health_literacy\n",
      "12_low_health_literacy\n",
      "12_intermediate_health_literacy\n",
      "12_proficient_health_literacy\n",
      "13_low_health_literacy\n",
      "13_intermediate_health_literacy\n",
      "13_proficient_health_literacy\n",
      "14_low_health_literacy\n",
      "14_intermediate_health_literacy\n",
      "14_proficient_health_literacy\n",
      "15_low_health_literacy\n",
      "15_intermediate_health_literacy\n",
      "15_proficient_health_literacy\n",
      "16_low_health_literacy\n",
      "16_intermediate_health_literacy\n",
      "16_proficient_health_literacy\n",
      "17_low_health_literacy\n",
      "17_intermediate_health_literacy\n",
      "17_proficient_health_literacy\n",
      "18_low_health_literacy\n",
      "18_intermediate_health_literacy\n",
      "18_proficient_health_literacy\n",
      "19_low_health_literacy\n",
      "19_intermediate_health_literacy\n",
      "19_proficient_health_literacy\n",
      "20_low_health_literacy\n",
      "20_intermediate_health_literacy\n",
      "20_proficient_health_literacy\n",
      "21_low_health_literacy\n",
      "21_intermediate_health_literacy\n",
      "21_proficient_health_literacy\n",
      "22_low_health_literacy\n",
      "22_intermediate_health_literacy\n",
      "22_proficient_health_literacy\n",
      "23_low_health_literacy\n",
      "23_intermediate_health_literacy\n",
      "23_proficient_health_literacy\n",
      "24_low_health_literacy\n",
      "24_intermediate_health_literacy\n",
      "24_proficient_health_literacy\n",
      "25_low_health_literacy\n",
      "25_intermediate_health_literacy\n",
      "25_proficient_health_literacy\n",
      "26_low_health_literacy\n",
      "26_intermediate_health_literacy\n",
      "26_proficient_health_literacy\n",
      "27_low_health_literacy\n",
      "27_intermediate_health_literacy\n",
      "27_proficient_health_literacy\n",
      "28_low_health_literacy\n",
      "28_intermediate_health_literacy\n",
      "28_proficient_health_literacy\n",
      "29_low_health_literacy\n",
      "29_intermediate_health_literacy\n",
      "29_proficient_health_literacy\n",
      "30_low_health_literacy\n",
      "30_intermediate_health_literacy\n",
      "30_proficient_health_literacy\n",
      "31_low_health_literacy\n",
      "31_intermediate_health_literacy\n",
      "31_proficient_health_literacy\n",
      "32_low_health_literacy\n",
      "32_intermediate_health_literacy\n",
      "32_proficient_health_literacy\n",
      "33_low_health_literacy\n",
      "33_intermediate_health_literacy\n",
      "33_proficient_health_literacy\n",
      "34_low_health_literacy\n",
      "34_intermediate_health_literacy\n",
      "34_proficient_health_literacy\n",
      "35_low_health_literacy\n",
      "35_intermediate_health_literacy\n",
      "35_proficient_health_literacy\n",
      "36_low_health_literacy\n",
      "36_intermediate_health_literacy\n",
      "36_proficient_health_literacy\n",
      "37_low_health_literacy\n",
| 267 |
+
"37_intermediate_health_literacy\n",
|
| 268 |
+
"37_proficient_health_literacy\n",
|
| 269 |
+
"38_low_health_literacy\n",
|
| 270 |
+
"38_intermediate_health_literacy\n",
|
| 271 |
+
"38_proficient_health_literacy\n",
|
| 272 |
+
"39_low_health_literacy\n",
|
| 273 |
+
"39_intermediate_health_literacy\n",
|
| 274 |
+
"39_proficient_health_literacy\n",
|
| 275 |
+
"40_low_health_literacy\n",
|
| 276 |
+
"40_intermediate_health_literacy\n",
|
| 277 |
+
"40_proficient_health_literacy\n",
|
| 278 |
+
"41_low_health_literacy\n",
|
| 279 |
+
"41_intermediate_health_literacy\n",
|
| 280 |
+
"41_proficient_health_literacy\n",
|
| 281 |
+
"42_low_health_literacy\n",
|
| 282 |
+
"42_intermediate_health_literacy\n",
|
| 283 |
+
"42_proficient_health_literacy\n",
|
| 284 |
+
"43_low_health_literacy\n",
|
| 285 |
+
"43_intermediate_health_literacy\n",
|
| 286 |
+
"43_proficient_health_literacy\n",
|
| 287 |
+
"44_low_health_literacy\n",
|
| 288 |
+
"44_intermediate_health_literacy\n",
|
| 289 |
+
"44_proficient_health_literacy\n",
|
| 290 |
+
"45_low_health_literacy\n",
|
| 291 |
+
"45_intermediate_health_literacy\n",
|
| 292 |
+
"45_proficient_health_literacy\n",
|
| 293 |
+
"46_low_health_literacy\n",
|
| 294 |
+
"46_intermediate_health_literacy\n",
|
| 295 |
+
"46_proficient_health_literacy\n",
|
| 296 |
+
"47_low_health_literacy\n",
|
| 297 |
+
"47_intermediate_health_literacy\n",
|
| 298 |
+
"47_proficient_health_literacy\n",
|
| 299 |
+
"48_low_health_literacy\n",
|
| 300 |
+
"48_intermediate_health_literacy\n",
|
| 301 |
+
"48_proficient_health_literacy\n",
|
| 302 |
+
"49_low_health_literacy\n",
|
| 303 |
+
"49_intermediate_health_literacy\n",
|
| 304 |
+
"49_proficient_health_literacy\n",
|
| 305 |
+
"50_low_health_literacy\n",
|
| 306 |
+
"50_intermediate_health_literacy\n",
|
| 307 |
+
"50_proficient_health_literacy\n",
|
| 308 |
+
"51_low_health_literacy\n",
|
| 309 |
+
"51_intermediate_health_literacy\n",
|
| 310 |
+
"51_proficient_health_literacy\n",
|
| 311 |
+
"52_low_health_literacy\n",
|
| 312 |
+
"52_intermediate_health_literacy\n",
|
| 313 |
+
"52_proficient_health_literacy\n",
|
| 314 |
+
"53_low_health_literacy\n",
|
| 315 |
+
"53_intermediate_health_literacy\n",
|
| 316 |
+
"53_proficient_health_literacy\n",
|
| 317 |
+
"54_low_health_literacy\n",
|
| 318 |
+
"54_intermediate_health_literacy\n",
|
| 319 |
+
"54_proficient_health_literacy\n",
|
| 320 |
+
"55_low_health_literacy\n",
|
| 321 |
+
"55_intermediate_health_literacy\n",
|
| 322 |
+
"55_proficient_health_literacy\n",
|
| 323 |
+
"56_low_health_literacy\n",
|
| 324 |
+
"56_intermediate_health_literacy\n",
|
| 325 |
+
"56_proficient_health_literacy\n",
|
| 326 |
+
"57_low_health_literacy\n",
|
| 327 |
+
"57_intermediate_health_literacy\n",
|
| 328 |
+
"57_proficient_health_literacy\n",
|
| 329 |
+
"58_low_health_literacy\n",
|
| 330 |
+
"58_intermediate_health_literacy\n",
|
| 331 |
+
"58_proficient_health_literacy\n",
|
| 332 |
+
"59_low_health_literacy\n",
|
| 333 |
+
"59_intermediate_health_literacy\n",
|
| 334 |
+
"59_proficient_health_literacy\n",
|
| 335 |
+
"60_low_health_literacy\n",
|
| 336 |
+
"60_intermediate_health_literacy\n",
|
| 337 |
+
"60_proficient_health_literacy\n",
|
| 338 |
+
"61_low_health_literacy\n",
|
| 339 |
+
"61_intermediate_health_literacy\n",
|
| 340 |
+
"61_proficient_health_literacy\n",
|
| 341 |
+
"62_low_health_literacy\n",
|
| 342 |
+
"62_intermediate_health_literacy\n",
|
| 343 |
+
"62_proficient_health_literacy\n",
|
| 344 |
+
"63_low_health_literacy\n",
|
| 345 |
+
"63_intermediate_health_literacy\n",
|
| 346 |
+
"63_proficient_health_literacy\n",
|
| 347 |
+
"64_low_health_literacy\n",
|
| 348 |
+
"64_intermediate_health_literacy\n",
|
| 349 |
+
"64_proficient_health_literacy\n",
|
| 350 |
+
"65_low_health_literacy\n",
|
| 351 |
+
"65_intermediate_health_literacy\n",
|
| 352 |
+
"65_proficient_health_literacy\n",
|
| 353 |
+
"66_low_health_literacy\n",
|
| 354 |
+
"66_intermediate_health_literacy\n",
|
| 355 |
+
"66_proficient_health_literacy\n",
|
| 356 |
+
"67_low_health_literacy\n",
|
| 357 |
+
"67_intermediate_health_literacy\n",
|
| 358 |
+
"67_proficient_health_literacy\n",
|
| 359 |
+
"68_low_health_literacy\n",
|
| 360 |
+
"68_intermediate_health_literacy\n",
|
| 361 |
+
"68_proficient_health_literacy\n",
|
| 362 |
+
"69_low_health_literacy\n",
|
| 363 |
+
"69_intermediate_health_literacy\n",
|
| 364 |
+
"69_proficient_health_literacy\n",
|
| 365 |
+
"70_low_health_literacy\n",
|
| 366 |
+
"70_intermediate_health_literacy\n",
|
| 367 |
+
"70_proficient_health_literacy\n",
|
| 368 |
+
"71_low_health_literacy\n",
|
| 369 |
+
"71_intermediate_health_literacy\n",
|
| 370 |
+
"71_proficient_health_literacy\n",
|
| 371 |
+
"72_low_health_literacy\n",
|
| 372 |
+
"72_intermediate_health_literacy\n",
|
| 373 |
+
"72_proficient_health_literacy\n",
|
| 374 |
+
"73_low_health_literacy\n",
|
| 375 |
+
"73_intermediate_health_literacy\n",
|
| 376 |
+
"73_proficient_health_literacy\n",
|
| 377 |
+
"74_low_health_literacy\n",
|
| 378 |
+
"74_intermediate_health_literacy\n",
|
| 379 |
+
"74_proficient_health_literacy\n",
|
| 380 |
+
"75_low_health_literacy\n",
|
| 381 |
+
"75_intermediate_health_literacy\n",
|
| 382 |
+
"75_proficient_health_literacy\n",
|
| 383 |
+
"76_low_health_literacy\n",
|
| 384 |
+
"76_intermediate_health_literacy\n",
|
| 385 |
+
"76_proficient_health_literacy\n",
|
| 386 |
+
"77_low_health_literacy\n",
|
| 387 |
+
"77_intermediate_health_literacy\n",
|
| 388 |
+
"77_proficient_health_literacy\n",
|
| 389 |
+
"78_low_health_literacy\n",
|
| 390 |
+
"78_intermediate_health_literacy\n",
|
| 391 |
+
"78_proficient_health_literacy\n",
|
| 392 |
+
"79_low_health_literacy\n",
|
| 393 |
+
"79_intermediate_health_literacy\n",
|
| 394 |
+
"79_proficient_health_literacy\n"
|
| 395 |
+
]
|
| 396 |
+
}
|
| 397 |
+
],
|
| 398 |
+
"source": [
|
| 399 |
+
"# /home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_en_0_80_full_updated.json\n",
|
| 400 |
+
"with open('/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_en_0_80_full_updated.json', 'r') as f:\n",
|
| 401 |
+
" syn_data_diff_labels_en_0_80_full_updated = json.load(f)\n",
|
| 402 |
+
"map_data={}\n",
|
| 403 |
+
"for item in syn_data_diff_labels_en_0_80_full_updated:\n",
|
| 404 |
+
" for label in list(item['diff_label_texts'].keys()):\n",
|
| 405 |
+
" key=f\"{item['index']}_{label}\"\n",
|
| 406 |
+
" print(key)\n",
|
| 407 |
+
" map_data[key]={\n",
|
| 408 |
+
" 'doc_id':item['index'],\n",
|
| 409 |
+
" 'label':label,\n",
|
| 410 |
+
" 'fulltext':item['fulltext'],\n",
|
| 411 |
+
" \"diff_label_texts\":item['diff_label_texts'][label],\n",
|
| 412 |
+
" 'summary':item['summary']\n",
|
| 413 |
+
" }\n"
|
| 414 |
+
]
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"cell_type": "code",
|
| 418 |
+
"execution_count": 28,
|
| 419 |
+
"id": "c52e96ab",
|
| 420 |
+
"metadata": {},
|
| 421 |
+
"outputs": [],
|
| 422 |
+
"source": [
|
| 423 |
+
"# /home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/combine/consolidated_ratings_0-20(not_all_category).json\n",
|
| 424 |
+
"with open('/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/combine/consolidated_ratings_0-20(not_all_category).json', 'r') as f:\n",
|
| 425 |
+
" consolidated_ratings_0_20 = json.load(f)\n",
|
| 426 |
+
"new_data=[]\n",
|
| 427 |
+
"for item in consolidated_ratings_0_20:\n",
|
| 428 |
+
" key=f\"{item['doc_id']}_{item['health_literacy_label']}\"\n",
|
| 429 |
+
" new_data.append({\n",
|
| 430 |
+
" **map_data[key],\n",
|
| 431 |
+
" })\n"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"cell_type": "code",
|
| 436 |
+
"execution_count": 29,
|
| 437 |
+
"id": "bfd6cf96",
|
| 438 |
+
"metadata": {},
|
| 439 |
+
"outputs": [],
|
| 440 |
+
"source": [
|
| 441 |
+
"with open('/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/combine/verified_data_0-20.json', 'w') as f:\n",
|
| 442 |
+
" json.dump(new_data, f, indent=4)\n"
|
| 443 |
+
]
|
| 444 |
+
},
|
| 445 |
+
{
|
| 446 |
+
"cell_type": "code",
|
| 447 |
+
"execution_count": null,
|
| 448 |
+
"id": "cf797af6",
|
| 449 |
+
"metadata": {},
|
| 450 |
+
"outputs": [],
|
| 451 |
+
"source": []
|
| 452 |
+
}
|
| 453 |
+
],
|
| 454 |
+
"metadata": {
|
| 455 |
+
"kernelspec": {
|
| 456 |
+
"display_name": "un",
|
| 457 |
+
"language": "python",
|
| 458 |
+
"name": "python3"
|
| 459 |
+
},
|
| 460 |
+
"language_info": {
|
| 461 |
+
"codemirror_mode": {
|
| 462 |
+
"name": "ipython",
|
| 463 |
+
"version": 3
|
| 464 |
+
},
|
| 465 |
+
"file_extension": ".py",
|
| 466 |
+
"mimetype": "text/x-python",
|
| 467 |
+
"name": "python",
|
| 468 |
+
"nbconvert_exporter": "python",
|
| 469 |
+
"pygments_lexer": "ipython3",
|
| 470 |
+
"version": "3.11.14"
|
| 471 |
+
}
|
| 472 |
+
},
|
| 473 |
+
"nbformat": 4,
|
| 474 |
+
"nbformat_minor": 5
|
| 475 |
+
}
|
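The notebook cells above join the synthetic texts to the annotator ratings through a `"{doc_id}_{label}"` key. A minimal self-contained sketch of that join, using made-up toy records whose field names mirror the notebook:

```python
# Toy stand-ins for syn_data_diff_labels_... and consolidated_ratings_...;
# the records are fabricated for illustration only.
syn_items = [
    {"index": 0, "fulltext": "full text 0", "summary": "summary 0",
     "diff_label_texts": {"low_health_literacy": "easy text",
                          "proficient_health_literacy": "expert text"}},
]
ratings = [{"doc_id": 0, "health_literacy_label": "low_health_literacy"}]

# Build the lookup keyed by "<doc_id>_<label>", one entry per (document, level) pair
map_data = {}
for item in syn_items:
    for label in item["diff_label_texts"]:
        map_data[f"{item['index']}_{label}"] = {
            "doc_id": item["index"],
            "label": label,
            "fulltext": item["fulltext"],
            "diff_label_texts": item["diff_label_texts"][label],
            "summary": item["summary"],
        }

# Attach the full record to each annotator rating via the same key
new_data = [dict(map_data[f"{r['doc_id']}_{r['health_literacy_label']}"]) for r in ratings]
print(new_data[0]["diff_label_texts"])  # easy text
```

Note that the join raises `KeyError` if a rating references a (doc_id, label) pair missing from the synthetic set, which is the behavior the notebook relies on to surface mismatches.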
code/RL_model/unsloth_rl/claim_verifier.py
ADDED
@@ -0,0 +1,175 @@
+import json
+import re
+import concurrent.futures
+from openai import OpenAI
+
+class MedicalClaimVerifier:
+    def __init__(self):
+        # OpenAI API configuration
+        api_file = "/home/mshahidul/api_new.json"
+        with open(api_file, "r") as f:
+            api_keys = json.load(f)
+        self.api_key = api_keys["openai"]
+        self.model_name = "gpt-5-mini"
+        self.client = OpenAI(api_key=self.api_key)
+
+        # Literacy ranges (IQR after outlier removal) from paper summary
+        # comp = completeness vs gold summary; cov = source_coverage vs full text
+        self.threshold_ranges = {
+            "low": {"comp": (0.9600, 1.0000), "cov": (0.1765, 0.3226)},
+            "intermediate": {"comp": (0.9393, 1.0000), "cov": (0.1818, 0.4091)},
+            "proficient": {"comp": (0.9231, 1.0000), "cov": (0.7725, 0.9347)},
+        }
+
+        # Minimum required information (upper bound of IQR)
+        self.thresholds = {
+            "low": {"comp": 1.0, "cov": 0.3226},
+            "intermediate": {"comp": 1.0, "cov": 0.4091},
+            "proficient": {"comp": 1.0, "cov": 0.9347},
+        }
+
+    def get_prompt(self, context, claim):
+        prompt = f"""
+CONTEXT:
+{context}
+
+CLAIM TO VERIFY:
+{claim}
+
+INSTRUCTION:
+Does the CONTEXT above provide enough evidence to support the CLAIM?
+- Answer 'supported' if the claim is explicitly stated or logically followable.
+- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.
+
+Output only one word: 'supported' or 'not_supported'.
+"""
+        return prompt
+
+    def check_support_api(self, prompt):
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            res = response.choices[0].message.content.strip().lower()
+            # print("API Response:", res)
+            return 1.0 if "supported" in res and "not_supported" not in res else 0.0
+        except Exception as e:
+            print(f"API call error: {e}")
+            return 0.0
+
+    def evaluate_level(self, gen_text, gold_subs, full_subs, level_key):
+        """Calculates scores for a single literacy level."""
+        if not gen_text:
+            return 0.0, 0.0
+
+        # Run API calls in parallel to save time during RL
+        try:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                # Completeness check (vs Gold Summary Subclaims)
+                comp_prompts = [self.get_prompt(gen_text, s) for s in gold_subs]
+                comp_results = list(executor.map(self.check_support_api, comp_prompts))
+                comp_score = sum(comp_results) / len(comp_results) if comp_results else 0.0
+
+                # Coverage check (vs Full Text Subclaims)
+                cov_prompts = [self.get_prompt(gen_text, s) for s in full_subs]
+                cov_results = list(executor.map(self.check_support_api, cov_prompts))
+                cov_score = sum(cov_results) / len(cov_results) if cov_results else 0.0
+                # print(f"Comp Score: {comp_score}, Cov Score: {cov_score} for {level_key}")
+        except Exception as e:
+            print(f"Parallel API call error: {e}")
+            return 0.0, 0.0
+
+        return comp_score, cov_score
+
+    def get_reward_score(self, completion, gold_subs, full_subs):
+        data = None
+
+        # 1. Robust JSON Extraction
+        try:
+            # Clean potential markdown or whitespace
+            text = completion[0]['content'].strip().replace("```json", "").replace("```", "").strip()
+            data = json.loads(text)
+        except (json.JSONDecodeError, IndexError, ValueError):
+            print("JSON Parsing Error in Reward Calculation")
+            # If all extraction attempts fail
+            return -5.0
+
+        # 2. Schema Validation
+        levels = ["low", "intermediate", "proficient"]
+        # Check if any required keys are missing
+        if not all(f"{lvl}_health_literacy" in data for lvl in levels):
+            return -2.0  # Slightly smaller penalty for partial formatting success
+
+        # 3. Scoring Logic
+        try:
+            total_reward = 0.0
+            pass_reward = 1.0
+            fail_penalty = -1.0
+            for lvl in levels:
+                gen_text = data.get(f"{lvl}_health_literacy", "")
+
+                # Skip scoring if text is empty
+                if not gen_text:
+                    total_reward += fail_penalty
+                    continue
+
+                comp_score, cov_score = self.evaluate_level(gen_text, gold_subs, full_subs, lvl)
+
+                # Apply Thresholds
+                total_reward += pass_reward if comp_score >= self.thresholds[lvl]["comp"] else fail_penalty
+                total_reward += pass_reward if cov_score >= self.thresholds[lvl]["cov"] else fail_penalty
+
+            return total_reward
+        except Exception:
+            return -5.0
+
+
+# 1. Ground Truth Subclaims (Extracted from a medical paper on Hypertension)
+gold_summary_subclaims = [
+    "Hypertension is defined as blood pressure above 140/90 mmHg.",
+    "Lifestyle changes like low salt intake can reduce blood pressure.",
+    "Diuretics are often the first line of pharmacological treatment."
+]
+
+full_text_subclaims = [
+    "Hypertension is defined as blood pressure above 140/90 mmHg.",
+    "Lifestyle changes like low salt intake can reduce blood pressure.",
+    "Diuretics are often the first line of pharmacological treatment.",
+    "The DASH diet emphasizes fruits, vegetables, and low-fat dairy.",
+    "Chronic hypertension increases the risk of stroke and myocardial infarction.",
+    "ACE inhibitors are contraindicated during pregnancy.",
+    "Secondary hypertension can be caused by renal artery stenosis."
+]
+
+# 2. Mock Model Completion (The output being evaluated)
+# This mimics the format your RL environment would pass to the reward function
+mock_completion = [{
+    'content': """
+{
+    "low_health_literacy": "High blood pressure is when your blood is too strong for your veins. You should eat less salt to help stay healthy.",
+    "intermediate_health_literacy": "Hypertension is blood pressure over 140/90. You can lower it by eating less salt and taking water pills (diuretics) if your doctor says so.",
+    "proficient_health_literacy": "Hypertension (BP > 140/90 mmHg) is managed via lifestyle modifications like the DASH diet and salt restriction. Pharmacological interventions include diuretics as first-line therapy, though risks like stroke or heart attack persist if untreated. Secondary causes like renal artery stenosis should be screened, and ACE inhibitors must be avoided in pregnancy."
+}
+"""
+}]
+
+# Initialize the verifier
+verifier = MedicalClaimVerifier()
+
+# Test the reward calculation
+reward = verifier.get_reward_score(
+    completion=mock_completion,
+    gold_subs=gold_summary_subclaims,
+    full_subs=full_text_subclaims
+)
+
+print("--- Evaluation Result ---")
+print(f"Total Reward Score: {reward}")
+
+# Logic Explanation:
+# - Low: Likely fails 'comp' (missing 140/90 info), but might pass 'cov' (low threshold).
+# - Intermediate: Likely passes 'comp' and 'cov'.
+# - Proficient: Needs to cover almost all 7 subclaims to pass the 0.9347 coverage threshold.
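The reward shaping in `get_reward_score` is purely threshold-based: each literacy level contributes +1 or -1 per check (completeness and coverage), with format failures short-circuiting to -5 or -2. A minimal offline sketch of just that shaping step, with the GPT-based claim checks replaced by precomputed scores so it can run without API access (the thresholds are copied from the class above; `shape_reward` is a hypothetical helper name):

```python
# Per-level minimum thresholds, as in MedicalClaimVerifier.thresholds
THRESHOLDS = {
    "low": {"comp": 1.0, "cov": 0.3226},
    "intermediate": {"comp": 1.0, "cov": 0.4091},
    "proficient": {"comp": 1.0, "cov": 0.9347},
}

def shape_reward(scores):
    """scores: {level: (comp_score, cov_score)} -> threshold reward in [-6, +6]."""
    total = 0.0
    for lvl, (comp, cov) in scores.items():
        total += 1.0 if comp >= THRESHOLDS[lvl]["comp"] else -1.0
        total += 1.0 if cov >= THRESHOLDS[lvl]["cov"] else -1.0
    return total

# A completion that passes all six checks gets the maximum reward
perfect = {lvl: (1.0, 1.0) for lvl in THRESHOLDS}
print(shape_reward(perfect))  # 6.0

# A proficient-level summary covering only half the source fails its coverage check
partial = {"low": (1.0, 0.4), "intermediate": (1.0, 0.5), "proficient": (1.0, 0.5)}
print(shape_reward(partial))  # 4.0
```

Because the signal is binary per check, the reward is a step function of the underlying scores; GRPO-style training sees at most 13 distinct values (-6 to +6), which keeps the signal sparse but cheap to verify.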
code/RL_model/unsloth_rl/finetune.py
ADDED
@@ -0,0 +1,91 @@
+import os
+# Set GPU environment variables
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+import torch
+from unsloth import FastLanguageModel
+from datasets import load_dataset
+from trl import SFTTrainer, SFTConfig
+from unsloth.chat_templates import get_chat_template, standardize_data_formats, train_on_responses_only
+
+# 1. Configuration
+model_name = "unsloth/Qwen3-4B-Instruct-2507"
+max_seq_length = 8192
+dataset_path = "/home/mshahidul/readctrl/data/finetuning_data/training_data_readability_data_generation.json"
+output_dir = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model"
+
+# 2. Load Model and Tokenizer
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name = model_name,
+    max_seq_length = max_seq_length,
+    load_in_4bit = True,
+)
+
+# 3. Add LoRA Adapters
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 32,
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                      "gate_proj", "up_proj", "down_proj",],
+    lora_alpha = 32,
+    lora_dropout = 0,
+    bias = "none",
+    use_gradient_checkpointing = "unsloth",
+    random_state = 3407,
+)
+
+# 4. Data Preparation
+tokenizer = get_chat_template(
+    tokenizer,
+    chat_template = "qwen3-instruct",
+)
+
+dataset = load_dataset("json", data_files = dataset_path, split = "train")
+dataset = standardize_data_formats(dataset)
+
+def formatting_prompts_func(examples):
+    convos = examples["conversations"]
+    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
+    return { "text" : texts, }
+
+dataset = dataset.map(formatting_prompts_func, batched = True)
+
+# 5. Training Setup
+trainer = SFTTrainer(
+    model = model,
+    tokenizer = tokenizer,
+    train_dataset = dataset,
+    dataset_text_field = "text",
+    max_seq_length = max_seq_length,
+    args = SFTConfig(
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        warmup_steps = 5,
+        # max_steps = 60, # Adjust as needed for your dataset size
+        num_train_epochs = 3,
+        learning_rate = 2e-4,
+        fp16 = not torch.cuda.is_bf16_supported(),
+        bf16 = torch.cuda.is_bf16_supported(),
+        logging_steps = 1,
+        optim = "adamw_8bit",
+        weight_decay = 0.01,
+        lr_scheduler_type = "linear",
+        seed = 3407,
+        output_dir = "outputs",
+    ),
+)
+
+# Train only on assistant responses
+trainer = train_on_responses_only(
+    trainer,
+    instruction_part = "<|im_start|>user\n",
+    response_part = "<|im_start|>assistant\n",
+)
+
+# 6. Train and Save
+trainer.train()
+
+model.save_pretrained(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print(f"Model saved to {output_dir}")
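The `train_on_responses_only` call masks the loss so the model is only trained on tokens after the assistant marker, not on the user prompt. A character-level sketch of that idea, detached from Unsloth (the real helper operates on token ids; `response_mask` is a hypothetical name for illustration):

```python
# Markers matching the qwen3-instruct chat template used above
INSTRUCTION_PART = "<|im_start|>user\n"
RESPONSE_PART = "<|im_start|>assistant\n"

def response_mask(text):
    """Return a 0/1 mask over characters: 1 where the loss should apply."""
    mask = [0] * len(text)
    start = text.find(RESPONSE_PART)
    if start != -1:
        # Everything after the assistant marker is trainable
        for i in range(start + len(RESPONSE_PART), len(text)):
            mask[i] = 1
    return mask

text = INSTRUCTION_PART + "Simplify this.<|im_end|>\n" + RESPONSE_PART + "OK."
mask = response_mask(text)
# Only the 3 characters of the response "OK." contribute to the loss
print(sum(mask))  # 3
```

Masking the prompt this way prevents the model from spending capacity on reproducing user turns, which matters here since the prompts contain long medical source texts.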
code/RL_model/unsloth_rl/health_classifier.py
ADDED
@@ -0,0 +1,42 @@
+import dspy
+import json
+from typing import Literal
+
+# --- 1. LLM Configuration ---
+def setup_dspy_classifier(save_path, api_key_path):
+    with open(api_key_path, "r") as f:
+        api_keys = json.load(f)
+
+    # Configure the LM
+    # Note: 'gpt-5-mini' is used per your configuration; ensure this matches your provider
+    openai_model = dspy.LM(model='gpt-5-mini', api_key=api_keys["openai"])
+    dspy.configure(lm=openai_model)
+
+    class HealthLiteracySignature(dspy.Signature):
+        """
+        Judge the health literacy level of a generated medical summary.
+        Identify if the language is suitable for a layperson (low) or requires medical expertise (proficient).
+        """
+        summary_text: str = dspy.InputField(desc="The generated medical summary to be analyzed.")
+        reasoning: str = dspy.OutputField(desc="Analysis of jargon, acronyms, and sentence complexity.")
+        label: Literal["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"] = dspy.OutputField()
+
+    class HealthLiteracyClassifier(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.predictor = dspy.ChainOfThought(HealthLiteracySignature)
+
+        def forward(self, summary_text):
+            return self.predictor(summary_text=summary_text)
+
+    # Initialize and load weights
+    classifier_instance = HealthLiteracyClassifier()
+    classifier_instance.load(save_path)
+    return classifier_instance
+
+# Global instantiation (optional, or you can call setup in your main script)
+API_FILE = "/home/mshahidul/api_new.json"
+SAVE_PATH = "/home/mshahidul/readctrl/data/new_exp/optimized_health_classifier_gpt5-mini_v2.json"
+
+# Create the instance to be imported
+classifier = setup_dspy_classifier(SAVE_PATH, API_FILE)
code/RL_model/unsloth_rl/highlighter.py
ADDED
@@ -0,0 +1,103 @@
+import gradio as gr
+from transformers import AutoModel
+import torch
+
+# 1. Load the model (ensure you have transformers and torch installed)
+print("Loading model... This may take a moment.")
+model = AutoModel.from_pretrained(
+    "zilliz/semantic-highlight-bilingual-v1",
+    trust_remote_code=True
+)
+
+def process_and_highlight(question, context, threshold):
+    if not question or not context:
+        return "Please provide both a question and context."
+
+    # 2. Run the model inference
+    result = model.process(
+        question=question,
+        context=context,
+        threshold=threshold,
+        return_sentence_metrics=True
+    )
+
+    highlighted_sentences = result.get("highlighted_sentences", [])
+
+    # 3. Create the highlighted HTML output
+    # We iterate through the context and wrap highlighted sentences in HTML tags
+    output_html = context
+
+    # Sort highlighted sentences by length (descending) to avoid partial
+    # matching issues if one sentence is a substring of another
+    highlighted_sentences.sort(key=len, reverse=True)
+
+    for sent in highlighted_sentences:
+        # Use a bright yellow highlight style
+        style = "background-color: #fff176; color: #000; padding: 2px; border-radius: 3px; font-weight: 500;"
+        highlighted_tag = f'<span style="{style}">{sent}</span>'
+        output_html = output_html.replace(sent, highlighted_tag)
+
+    # Wrap in a container for better typography
+    final_output = f"""
+    <div style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; line-height: 1.8; font-size: 16px; color: #333;">
+        {output_html}
+    </div>
+    """
+
+    # 4. Format metrics for the display
+    metrics_str = "No specific probabilities returned."
+    if "sentence_probabilities" in result:
+        metrics_str = "\n".join([f"• {p:.4f}" for p in result["sentence_probabilities"]])
+
+    return final_output, metrics_str
+
+# 5. Build the Gradio UI
+with gr.Blocks(theme=gr.themes.Soft(), title="Semantic Highlighter") as demo:
+    gr.Markdown("# 🔍 Semantic Highlight Explorer")
+    gr.Markdown("Identify and highlight parts of a text that answer a specific question using the Zilliz bilingual model.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            question_input = gr.Textbox(
+                label="Question",
+                placeholder="e.g., What are the symptoms of dehydration?",
+                lines=2
+            )
+            context_input = gr.Textbox(
+                label="Context / Full Text",
+                placeholder="Paste the document text here...",
+                lines=10
+            )
+            threshold_slider = gr.Slider(
+                minimum=0.1, maximum=1.0, value=0.5, step=0.05,
+                label="Confidence Threshold"
+            )
+            submit_btn = gr.Button("Analyze & Highlight", variant="primary")
+
+        with gr.Column(scale=1):
+            gr.Label("Highlighted Result")
+            output_display = gr.HTML()
+
+            with gr.Accordion("Sentence Metrics", open=False):
+                metrics_display = gr.Textbox(label="Probabilities", lines=5)
+
+    # Add example from your snippet
+    gr.Examples(
+        examples=[
+            [
+                "What are the symptoms of dehydration?",
+                "Dehydration occurs when your body loses more fluid than you take in. Common signs include feeling thirsty and having a dry mouth. The human body is composed of about 60% water. Dark yellow urine and infrequent urination are warning signs. Water is essential for many bodily functions. Dizziness, fatigue, and headaches can indicate severe dehydration.",
+                0.5
+            ]
+        ],
+        inputs=[question_input, context_input, threshold_slider]
+    )
+
+    submit_btn.click(
+        fn=process_and_highlight,
+        inputs=[question_input, context_input, threshold_slider],
+        outputs=[output_display, metrics_display]
+    )
+
+if __name__ == "__main__":
+    demo.launch(share=True)
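The highlighting step in `process_and_highlight` is plain substring replacement, with the longest-first sort guarding against a sentence that is a substring of another being wrapped twice. A standalone sketch of just that logic, independent of the model and Gradio (`highlight` is a hypothetical helper name):

```python
def highlight(context, sentences, style="background-color: #fff176;"):
    """Wrap each matched sentence in a styled <span>, longest match first."""
    out = context
    for sent in sorted(sentences, key=len, reverse=True):
        out = out.replace(sent, f'<span style="{style}">{sent}</span>')
    return out

context = "Dark yellow urine is a warning sign. Water is essential."
html = highlight(context, ["Dark yellow urine is a warning sign."])
print(html)
```

One caveat of this approach: `str.replace` requires an exact character-level match, so any whitespace or punctuation drift between the model's returned sentences and the original context silently produces no highlight.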
code/RL_model/unsloth_rl/inference.py
ADDED
@@ -0,0 +1,120 @@
import json
import os

# Set GPU environment variables before importing torch / Unsloth
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from unsloth import FastLanguageModel
from transformers import TextStreamer

# 1. Configuration
model_path = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model"
max_seq_length = 8192

# 2. Load the fine-tuned model and tokenizer.
# Unsloth automatically reloads the base Qwen3 model and attaches your adapters.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=max_seq_length,
    load_in_4bit=False,
)

# 3. Enable fast inference.
# This activates Unsloth's optimized inference kernels for a 2x speedup.
FastLanguageModel.for_inference(model)

# 4. Prepare your test data.
# Replace these with actual values from your evaluation set.
gold_summary = "A 34-year-old pregnant woman presents with seizures and dysarthria and is urgently referred for a cranial MRI. The classic ‘Medusa head’ sign is seen and the diagnosis is made as a venous anomaly of development with peripheral partial thrombosis and proximal slow flow.\n"
fulltext = "We present the case of a 34-year-old woman, eight weeks pregnant with no other personal history of interest, who presents to the emergency department with generalized convulsions with dysarthria in the postcritical period, which resolve progressively in less than two hours. On physical examination, she is conscious, oriented, with no language or motor or sensory deficits. Only signs of a right lateral tongue bite are observed.\n\nThe complementary tests, such as blood tests or the electrocardiogram, are normal. Given that the episode corresponds with a first epileptic seizure and the patient is pregnant, an urgent magnetic resonance of the skull is requested.\n\nThe usual protocol was performed and 3D T1 sequences without and with intravenous contrast were obtained in axial, coronal and sagital planes, axial FLAIR, axial T2, VEN BOLD and magnetic susceptibility sequences, as well as axial diffusion and apparent diffusion coefficient map. The MRI identified multiple venous cortico-medullary vascular structures converging centripetally to a large central venous structure draining through the inferior anastomotic vein into the left transverse sinus, forming the classic ‘Medusa head’ sign. In the T1 sequences, the drainage vein was seen to be increased in signal with central hyphocaptation after contrast administration, suggesting partial thrombosis versus slow flow. In addition, in T2 and FLAIR sequences, the brain tissue surrounding the drainage vein was seen to be hyperintense, without diffusion restriction and compatible with edema.\n\nThese findings are suggestive of a venous anomaly of development with signs of partial peripheral thrombosis and slow flow more proximal, which cause edema of the surrounding tissue. She is started on clexane 60 mg/12 hours and levetiracetam 500 mg/12 hours and the patient shows improvement and symptomatic stability after one week.\n"

# Define the full instruction prompt
system_prompt = f"""
**System Role:**

You are an expert medical editor and Health Literacy specialist. Your task is to transform complex medical text into three distinct versions based on the reader's health literacy level. You must maintain the source language of the input while adjusting the linguistic complexity. Use the provided Gold Summary as the factual anchor to ensure the simplified versions remain accurate and focused on the most important information.

**User Prompt:**

Please process the following medical Source Text and its corresponding Gold Summary to generate three versions tailored to different health literacy levels.
### Instructions for Each Level:

1. Level: Low Health Literacy (High Readability)

Target: Individuals needing the simplest terms for immediate action.

Linguistic Goal: Use "living room" language. Replace all medical jargon with functional descriptions (e.g., "renal" becomes "kidney").

Information Density: Focus strictly on the "need-to-know" info found in the Gold Summary.

Strategy: High paraphrasing using analogies. One idea per sentence.

Faithfulness: Must align perfectly with the Gold Summary.

2. Level: Intermediate Health Literacy (Medium Readability)

Target: The general public (news-reading level).

Linguistic Goal: Standard vocabulary. Common medical terms are okay, but technical "doctor-speak" must be simplified.

Information Density: Balanced. Use the Gold Summary as the lead, supplemented by necessary context from the Source Text.

Strategy: Moderate paraphrasing. Remove minor technical details to avoid information overload.

Faithfulness: Maintains the main narrative of the Gold Summary.

3. Level: Proficient Health Literacy (Low Readability)

Target: Researchers, clinicians, or highly informed patients.

Linguistic Goal: Technical and academic language. Prioritize clinical nuance and medical accuracy.

Information Density: High. Use the Full Source Text to include data, physiological mechanisms, and statistics.

Strategy: Minimal paraphrasing. Retain all original technical terminology.

Faithfulness: Adhere to the Source Text; you may add related subclaims that provide deeper scientific context.

Input Language: English
Gold Summary (The Anchor):
{gold_summary}
Source Text (The Detail):
{fulltext}

**Output Format (JSON only):**
{{
  "low_health_literacy": "...",
  "intermediate_health_literacy": "...",
  "proficient_health_literacy": "..."
}}
"""

# Format for the Qwen-3 instruction template
messages = [
    {"role": "user", "content": system_prompt}
]

input_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

inputs = tokenizer([input_text], return_tensors="pt").to("cuda")

# 5. Run generation using the recommended sampling parameters
#    for Qwen3 non-thinking mode.
text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

print("--- Model Response ---")
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=2048,
    temperature=0.7,
    top_p=0.8,
    top_k=20,
    repetition_penalty=1.05,
    use_cache=True,
)
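The script above only streams the generation to stdout. Downstream code (not part of this commit) would need to recover the three-level JSON from the raw text; a minimal sketch, assuming the model may wrap its answer in a Markdown code fence (the fence-stripping mirrors the fallback logic used in reward_mock.py; `parse_literacy_output` and the sample string are illustrative):

```python
import json

def parse_literacy_output(generated: str) -> dict:
    """Extract the three-level JSON object from raw model output,
    tolerating an optional Markdown code fence around it."""
    text = generated.strip()
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0].strip()
    elif "```" in text:
        text = text.split("```")[1].split("```")[0].strip()
    return json.loads(text)

# Hypothetical model output for demonstration
raw = ('```json\n'
       '{"low_health_literacy": "A", '
       '"intermediate_health_literacy": "B", '
       '"proficient_health_literacy": "C"}\n'
       '```')
levels = parse_literacy_output(raw)
print(levels["low_health_literacy"])
```

In practice the same parser should be shared between inference and reward code so that both accept exactly the same output formats.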
code/RL_model/unsloth_rl/prompt
ADDED
@@ -0,0 +1,58 @@
**System Role:**

You are an expert medical editor and Health Literacy specialist. Your task is to transform complex medical text into three distinct versions based on the reader's health literacy level. You must maintain the source language of the input while adjusting the linguistic complexity. Use the provided Gold Summary as the factual anchor to ensure the simplified versions remain accurate and focused on the most important information.

**User Prompt:**

Please process the following medical Source Text and its corresponding Gold Summary to generate three versions tailored to different health literacy levels.
### Instructions for Each Level:

1. Level: Low Health Literacy (High Readability)

Target: Individuals needing the simplest terms for immediate action.

Linguistic Goal: Use "living room" language. Replace all medical jargon with functional descriptions (e.g., "renal" becomes "kidney").

Information Density: Focus strictly on the "need-to-know" info found in the Gold Summary.

Strategy: High paraphrasing using analogies. One idea per sentence.

Faithfulness: Must align perfectly with the Gold Summary.

2. Level: Intermediate Health Literacy (Medium Readability)

Target: The general public (news-reading level).

Linguistic Goal: Standard vocabulary. Common medical terms are okay, but technical "doctor-speak" must be simplified.

Information Density: Balanced. Use the Gold Summary as the lead, supplemented by necessary context from the Source Text.

Strategy: Moderate paraphrasing. Remove minor technical details to avoid information overload.

Faithfulness: Maintains the main narrative of the Gold Summary.

3. Level: Proficient Health Literacy (Low Readability)

Target: Researchers, clinicians, or highly informed patients.

Linguistic Goal: Technical and academic language. Prioritize clinical nuance and medical accuracy.

Information Density: High. Use the Full Source Text to include data, physiological mechanisms, and statistics.

Strategy: Minimal paraphrasing. Retain all original technical terminology.

Faithfulness: Adhere to the Source Text; you may add related subclaims that provide deeper scientific context.


I will provide the following information:

- Input Language: <<<SOURCE_LANGUAGE>>>
- Gold Summary (the anchor reference summary): <<<GOLD_SUMMARY>>>
- Source Text (detailed content): <<<FULL_TEXT>>>

**Output Format (JSON only):**
{{
  "low_health_literacy": "...",
  "intermediate_health_literacy": "...",
  "proficient_health_literacy": "..."
}}
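The `<<<...>>>` markers above are template placeholders that must be filled per example before the prompt is sent to the model. A minimal sketch of that substitution step (the `fill_prompt` helper is illustrative, not part of this commit):

```python
def fill_prompt(template: str, language: str, gold_summary: str, fulltext: str) -> str:
    """Substitute the three template placeholders with per-example data."""
    return (template
            .replace("<<<SOURCE_LANGUAGE>>>", language)
            .replace("<<<GOLD_SUMMARY>>>", gold_summary)
            .replace("<<<FULL_TEXT>>>", fulltext))

# Abbreviated template for demonstration
template = ("Input Language: <<<SOURCE_LANGUAGE>>>\n"
            "Gold Summary: <<<GOLD_SUMMARY>>>\n"
            "Source Text: <<<FULL_TEXT>>>")
prompt = fill_prompt(template, "English", "summary text", "full text")
```

Note that the `{{` / `}}` braces in the output-format section are escaped for Python f-string formatting, so they must stay doubled if the template is also passed through `str.format`.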
code/RL_model/unsloth_rl/reward_mock.py
ADDED
@@ -0,0 +1,127 @@
import os
import json
import re
import concurrent.futures
from openai import OpenAI


class MedicalClaimVerifier:
    def __init__(self):
        # Load the API key from disk; fail fast if the file is missing.
        api_file = "/home/mshahidul/api_new.json"
        with open(api_file, "r") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        # Note: ensure gpt-5-nano is actually available in your tier.
        self.model_name = "gpt-5-nano"
        self.client = OpenAI(api_key=self.api_key)

        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        prompt = f"""
CONTEXT:
{context}

CLAIM TO VERIFY:
{claim}

INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically follows.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.

Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    def check_support_api(self, prompt):
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            res = response.choices[0].message.content.strip().lower()
            return 1.0 if "supported" in res and "not_supported" not in res else 0.0
        except Exception:
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs):
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0

        # Batch both claim sets into one thread pool to reduce overhead.
        all_claims = gold_subs + full_subs
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            results = list(executor.map(
                self.check_support_api,
                [self.get_prompt(gen_text, s) for s in all_claims],
            ))

        comp_results = results[:len(gold_subs)]
        cov_results = results[len(gold_subs):]

        comp_score = sum(comp_results) / len(gold_subs)
        cov_score = sum(cov_results) / len(full_subs)
        return comp_score, cov_score


verifier = MedicalClaimVerifier()


def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    gold_subs = ground_truth.get('summary_subclaims', [])
    full_subs = ground_truth.get('fulltext_subclaims', [])

    if not gold_subs or not full_subs:
        return 0.0

    # 1. Parse the model output, tolerating Markdown code fences.
    try:
        cleaned_str = solution_str.strip()
        if "```json" in cleaned_str:
            cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
        elif "```" in cleaned_str:
            cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
        data = json.loads(cleaned_str)
    except Exception:
        return -5.0

    levels = ["low", "intermediate", "proficient"]
    scores = {}

    # 2. Score calculation
    for lvl in levels:
        gen_text = data.get(f"{lvl}_health_literacy", "")
        if not gen_text:
            scores[lvl] = {"comp": 0.0, "cov": 0.0, "missing": True}
        else:
            comp, cov = verifier.evaluate_level(gen_text, gold_subs, full_subs)
            scores[lvl] = {"comp": comp, "cov": cov, "missing": False}

    # 3. Reward shaping logic
    total_reward = 0.0

    low_cov = scores["low"]["cov"]
    int_cov = scores["intermediate"]["cov"]
    pro_cov = scores["proficient"]["cov"]

    # Soft hierarchy check: instead of a hard -2.0 early exit, subtract a
    # penalty when coverage does not increase with literacy level.
    hierarchy_penalty = 0.0
    if not (low_cov <= int_cov <= pro_cov):
        hierarchy_penalty = -2.0

    for lvl in levels:
        if scores[lvl]["missing"]:
            total_reward -= 1.0  # penalty per missing field
            continue

        comp_s = scores[lvl]["comp"]
        cov_s = scores[lvl]["cov"]
        thresh = verifier.thresholds[lvl]

        # Continuous reward (actual - threshold): this tells the model
        # "you're 10% away" rather than just "you failed".
        total_reward += (comp_s - thresh["comp"])
        total_reward += (cov_s - thresh["cov"])

    return total_reward + hierarchy_penalty
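The shaped reward in compute_score reduces to a sum of (score - threshold) deltas per metric, plus the hierarchy and missing-field penalties. A small worked sketch of that arithmetic with hypothetical verifier scores (thresholds copied from the class above):

```python
# Coverage thresholds per literacy level, as defined in MedicalClaimVerifier
thresholds = {
    "low": {"comp": 1.0, "cov": 0.3226},
    "intermediate": {"comp": 1.0, "cov": 0.4091},
    "proficient": {"comp": 1.0, "cov": 0.9347},
}

# Hypothetical per-level (comp, cov) scores returned by the verifier
scores = {
    "low": (1.0, 0.40),
    "intermediate": (1.0, 0.50),
    "proficient": (1.0, 0.95),
}

# Continuous shaping: sum of (actual - threshold) over both metrics
reward = 0.0
for lvl, (comp, cov) in scores.items():
    reward += (comp - thresholds[lvl]["comp"]) + (cov - thresholds[lvl]["cov"])

# Hierarchy check: coverage must be non-decreasing across levels
covs = [scores[lvl][1] for lvl in ("low", "intermediate", "proficient")]
if not (covs[0] <= covs[1] <= covs[2]):
    reward -= 2.0

print(round(reward, 4))
```

Here every comp delta is zero (comp thresholds are 1.0) and coverage increases monotonically, so the reward is just the sum of the three positive coverage margins.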
code/RL_model/unsloth_rl/test_reward_mock_unittest.py
ADDED
@@ -0,0 +1,139 @@
"""Minimal, offline tests for reward_mock.py.

Run:
    python code/RL_model/unsloth_rl/test_reward_mock_unittest.py

These tests avoid real OpenAI calls by:
- mocking the API key file read
- stubbing OpenAI client construction
- overriding verifier.evaluate_level with deterministic outputs
"""

from __future__ import annotations

import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import mock_open, patch


THIS_DIR = Path(__file__).resolve().parent
REWARD_MOCK_PATH = THIS_DIR / "reward_mock.py"


class FakeOpenAI:
    def __init__(self, api_key: str | None = None, **_kwargs):
        self.api_key = api_key


def load_reward_mock_module():
    """Load reward_mock.py from its file path under test-friendly patches."""
    module_name = "reward_mock_under_test"
    if module_name in sys.modules:
        del sys.modules[module_name]

    spec = importlib.util.spec_from_file_location(module_name, str(REWARD_MOCK_PATH))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to create import spec for {REWARD_MOCK_PATH}")

    module = importlib.util.module_from_spec(spec)

    # Ensure the 'openai' import is available and the OpenAI ctor is patched.
    # reward_mock does: `from openai import OpenAI`
    with patch("builtins.open", mock_open(read_data='{"openai": "sk-test"}')):
        with patch("openai.OpenAI", FakeOpenAI):
            spec.loader.exec_module(module)

    sys.modules[module_name] = module
    return module


class TestRewardMockComputeScore(unittest.TestCase):
    def test_valid_json_progression_no_hierarchy_penalty(self):
        rm = load_reward_mock_module()

        def fake_evaluate_level(gen_text, gold_subs, full_subs):
            # Return (comp, cov) deterministically based on the generated text.
            if gen_text == "LOW":
                return 1.0, 0.3000
            if gen_text == "INTER":
                return 1.0, 0.4000
            if gen_text == "PRO":
                return 1.0, 0.9500
            return 0.0, 0.0

        rm.verifier.evaluate_level = fake_evaluate_level

        solution_str = """```json
{
  "low_health_literacy": "LOW",
  "intermediate_health_literacy": "INTER",
  "proficient_health_literacy": "PRO"
}
```"""

        ground_truth = {
            "summary_subclaims": ["a", "b"],
            "fulltext_subclaims": ["x", "y", "z"],
        }

        score = rm.compute_score(data_source=None, solution_str=solution_str, ground_truth=ground_truth)

        # comp thresholds are 1.0 -> comp deltas = 0
        # cov deltas: (0.3000-0.3226) + (0.4000-0.4091) + (0.9500-0.9347) = -0.0164
        self.assertAlmostEqual(score, -0.0164, places=4)

    def test_missing_field_penalizes_and_triggers_hierarchy_penalty(self):
        rm = load_reward_mock_module()

        def fake_evaluate_level(gen_text, gold_subs, full_subs):
            if gen_text == "LOW":
                return 1.0, 0.3000
            if gen_text == "PRO":
                return 1.0, 0.9500
            return 0.0, 0.0

        rm.verifier.evaluate_level = fake_evaluate_level

        # intermediate is missing => -1.0
        # BUT its cov will be 0.0 for the hierarchy check, so low_cov(0.3) <= int_cov(0.0) fails => -2.0
        solution_str = '{"low_health_literacy": "LOW", "proficient_health_literacy": "PRO"}'

        ground_truth = {
            "summary_subclaims": ["a"],
            "fulltext_subclaims": ["x"],
        }

        score = rm.compute_score(data_source=None, solution_str=solution_str, ground_truth=ground_truth)
        expected = (0.3000 - 0.3226) + (0.9500 - 0.9347) - 1.0 - 2.0
        self.assertAlmostEqual(score, expected, places=4)

    def test_invalid_json_returns_minus_five(self):
        rm = load_reward_mock_module()

        ground_truth = {
            "summary_subclaims": ["a"],
            "fulltext_subclaims": ["x"],
        }

        score = rm.compute_score(data_source=None, solution_str="not a json", ground_truth=ground_truth)
        self.assertEqual(score, -5.0)

    def test_missing_claims_returns_zero(self):
        rm = load_reward_mock_module()

        solution_str = '{"low_health_literacy": "LOW", "intermediate_health_literacy": "INTER", "proficient_health_literacy": "PRO"}'

        # Missing subclaims => early return 0.0
        score = rm.compute_score(
            data_source=None,
            solution_str=solution_str,
            ground_truth={"summary_subclaims": [], "fulltext_subclaims": ["x"]},
        )
        self.assertEqual(score, 0.0)


if __name__ == "__main__":
    unittest.main(verbosity=2)
code/RL_model/unsloth_rl/testing.py
ADDED
@@ -0,0 +1,215 @@
import json
import concurrent.futures
from unittest.mock import MagicMock

# --- The Class (modified slightly for standalone demo) ---

class MedicalClaimVerifier:
    def __init__(self, mock_mode=False):
        self.thresholds = {
            "low": {"comp": 0.6107, "cov": 0.3723},
            "intermediate": {"comp": 0.8199, "cov": 0.6611},
            "proficient": {"comp": 0.9569, "cov": 0.9069}
        }
        self.mock_mode = mock_mode

        if not mock_mode:
            from openai import OpenAI
            self.api_url = "http://172.16.34.29:8004/v1"
            self.client = OpenAI(base_url=self.api_url, api_key="EMPTY")
            self.model_name = "qwen3-32b-readctrl"

    def get_audit_prompt(self, literacy_level):
        level_guidelines = {
            "low_health_literacy": """
Level: Low Health Literacy (High Readability)
Target: Individuals needing simple terms.
Goal: 'Living room' language. Replace jargon (e.g., 'renal' -> 'kidney').
Density: Strictly 'need-to-know' info from Gold Summary.
Strategy: High paraphrasing, analogies, one idea per sentence.
Faithfulness: Must align with Gold Summary.""",

            "intermediate_health_literacy": """
Level: Intermediate Health Literacy (Medium Readability)
Target: General public.
Goal: Standard vocabulary. Common medical terms okay; technical speak simplified.
Density: Balanced. Use Gold Summary as lead, supplemented by context from Source.
Strategy: Moderate paraphrasing. Remove minor technical details.
Faithfulness: Maintain main narrative of Gold Summary.""",

            "proficient_health_literacy": """
Level: Proficient Health Literacy (Low Readability)
Target: Researchers/Clinicians.
Goal: Technical/Academic. Prioritize clinical nuance and accuracy.
Density: High. Include data, physiological mechanisms, and statistics from Source.
Strategy: Minimal paraphrasing. Retain original technical terminology.
Faithfulness: Adhere to Source Text; add deeper scientific context."""
        }

        guidelines = level_guidelines.get(literacy_level, "Follow standard medical audit practices.")

        base_instructions = f"""
### Literacy Level Context:
{guidelines}

### Task Instructions:"""
        return base_instructions

    def get_completeness_prompt(self, generated_text, source_subclaim, literacy_level):
        base_instructions = self.get_audit_prompt(literacy_level)
        level_desc = literacy_level.replace("_", " ")
        return f"""{base_instructions}
1. Determine whether this Fact from the Gold Standard is covered in the {level_desc} summary.
2. Mark 'supported' ONLY IF:
   - The fact is explicitly stated in the summary, OR
   - The fact is clearly paraphrased or simplified in a way that preserves its meaning.
3. Do NOT mark 'supported' based solely on omission.
   - Absence of mention does NOT imply intentional exclusion.
   - Negative or exclusionary facts (e.g., "no complications," "no family history," "no systemic signs") must be explicitly conveyed.
4. Mark 'not_supported' if:
   - The fact is completely omitted, OR
   - The summary discusses related information but does not confirm the specific fact.
5. Literacy-based simplification is allowed, but factual meaning must be preserved.

SUMMARY: {generated_text}
FACT: {source_subclaim}

Output: 'supported' or 'not_supported'.
"""

    def get_source_coverage_prompt(self, generated_text, source_subclaim, literacy_level):
        base_instructions = self.get_audit_prompt(literacy_level)
        level_desc = literacy_level.replace("_", " ")
        return f"""{base_instructions}
1. Check whether the following Fact from the ORIGINAL Source Text is explicitly covered in the generated {level_desc} summary.
2. Mark 'supported' ONLY IF:
   - The summary clearly states the fact, OR
   - The fact is conveyed through an explicit paraphrase or simplification that preserves its meaning.
3. Do NOT infer support from silence or omission.
   - Absence of mention does NOT count as support.
   - Especially for negative or exclusionary facts (e.g., "no family history," "no extra-renal signs," "no complications"), the summary must explicitly indicate absence.
4. Mark 'not_supported' if:
   - The summary omits the fact entirely, OR
   - The summary discusses related topics but does not clearly confirm the specific fact.
5. Simplification for literacy level is allowed, but factual meaning must be preserved.

GENERATED SUMMARY: {generated_text}
SOURCE FACT: {source_subclaim}

Output: 'supported' or 'not_supported'."""

    def check_support_api(self, prompt):
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=300, temperature=0.1,
            )
            res = response.choices[0].message.content.strip().lower()
            print(f"Response Received:\n{res}\n")
            return 1.0 if "supported" in res and "not_supported" not in res else 0.0
        except Exception:
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs, level_key):
        if not gen_text:
            return 0.0, 0.0

        # Using 2 workers for the demo to keep overhead low
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            comp_prompts = [self.get_completeness_prompt(gen_text, s, level_key) for s in gold_subs]
            comp_results = list(executor.map(self.check_support_api, comp_prompts))
            comp_score = sum(comp_results) / len(comp_results) if comp_results else 0.0

            cov_prompts = [self.get_source_coverage_prompt(gen_text, s, level_key) for s in full_subs]
            cov_results = list(executor.map(self.check_support_api, cov_prompts))
            cov_score = sum(cov_results) / len(cov_results) if cov_results else 0.0

        return comp_score, cov_score

    def get_reward_score(self, completion, gold_subs, full_subs):
        data = None
        try:
            # completion[0]['content'] structure as expected by RL frameworks
            text = completion[0]['content'].strip()

            if "```json" in text:
                text = text.split("```json")[1].split("```")[0].strip()
            elif "```" in text:
                text = text.split("```")[1].split("```")[0].strip()

            if "<SOLUTION>" in text:
                text = text.split("<SOLUTION>")[-1].split("</SOLUTION>")[0].strip()

            data = json.loads(text)
        except Exception as e:
            print(f"JSON Parse Error: {e}")
            return -5.0

        levels = ["low", "intermediate", "proficient"]
        if not all(f"{lvl}_health_literacy" in data for lvl in levels):
            return -2.0

        try:
            total_reward = 0.0
            print("\n--- Evaluation Breakdown ---")
            for lvl in levels:
                gen_text = data.get(f"{lvl}_health_literacy", "")
                comp_score, cov_score = self.evaluate_level(gen_text, gold_subs, full_subs, f"{lvl}_health_literacy")

                # Threshold check
                comp_passed = comp_score >= self.thresholds[lvl]["comp"]
                cov_passed = cov_score >= self.thresholds[lvl]["cov"]

                total_reward += 1.0 if comp_passed else -0.5
                total_reward += 1.0 if cov_passed else -0.5

                print(f"[{lvl.upper()}] Comp: {comp_score:.2f} ({comp_passed}), Cov: {cov_score:.2f} ({cov_passed})")

            return total_reward
        except Exception as e:
            print(f"Scoring Error: {e}")
            return -5.0

# --- Execution Block ---
|
| 178 |
+
|
| 179 |
+
if __name__ == "__main__":
|
| 180 |
+
verifier = MedicalClaimVerifier(mock_mode=False)
|
| 181 |
+
|
| 182 |
+
# 1. Mock Input Data (what the model generated)
|
| 183 |
+
pass_completion = [{
|
| 184 |
+
"content": """
|
| 185 |
+
<SOLUTION>
|
| 186 |
+
{
|
| 187 |
+
"low_health_literacy": "This medicine makes it easier for your heart to pump and relaxes your blood tubes. You might feel dizzy if you stand up too fast.",
|
| 188 |
+
"intermediate_health_literacy": "ACE inhibitors like Lisinopril relax blood vessels to improve flow and lower heart attack risk. Side effects include low blood pressure.",
|
| 189 |
+
"proficient_health_literacy": "ACE inhibitors attenuate the effects of stress hormones on the myocardium while inducing vasodilation to reduce afterload and prevent myocardial infarction."
|
| 190 |
+
}
|
| 191 |
+
</SOLUTION>
|
| 192 |
+
"""
|
| 193 |
+
}]
|
| 194 |
+
|
| 195 |
+
# Completeness (Essential findings from a Gold Summary)
|
| 196 |
+
gold_subs = [
|
| 197 |
+
"ACE inhibitors help the heart pump better.",
|
| 198 |
+
"These medicines relax blood vessels.",
|
| 199 |
+
"Common side effects include dizziness and low blood pressure."
|
| 200 |
+
]
|
| 201 |
+
|
| 202 |
+
# Source Coverage (Detailed facts from the original Full Text)
|
| 203 |
+
full_subs = [
|
| 204 |
+
"Lisinopril is an example of an ACE inhibitor.",
|
| 205 |
+
"ACE inhibitors lower the risk of a heart attack.",
|
| 206 |
+
"The medication prevents stress hormones from damaging the heart.",
|
| 207 |
+
"Patients should stand up slowly to avoid dizziness."
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
# 3. Run Demo
|
| 211 |
+
print("Starting Demo Run...")
|
| 212 |
+
final_reward = verifier.get_reward_score(pass_completion, gold_subs, full_subs)
|
| 213 |
+
|
| 214 |
+
print("-" * 30)
|
| 215 |
+
print(f"FINAL REWARD SCORE: {final_reward}")
|
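The threshold aggregation in `get_reward_score` can be sanity-checked offline, without an LLM judge. A minimal sketch, assuming illustrative threshold values and scores (not the repo's actual configuration):

```python
# Illustrative sketch of the per-level threshold reward used in
# get_reward_score: +1.0 for each passed check, -0.5 for each failed one,
# summed over completeness and coverage across three literacy levels.
# Thresholds and scores below are made-up example values.

def aggregate_reward(scores, thresholds):
    """scores: {level: (comp_score, cov_score)}."""
    total = 0.0
    for lvl, (comp, cov) in scores.items():
        total += 1.0 if comp >= thresholds[lvl]["comp"] else -0.5
        total += 1.0 if cov >= thresholds[lvl]["cov"] else -0.5
    return total

thresholds = {lvl: {"comp": 0.8, "cov": 0.6}
              for lvl in ("low", "intermediate", "proficient")}
scores = {"low": (1.0, 0.75),           # both pass: +2.0
          "intermediate": (0.9, 0.5),   # coverage fails: +0.5
          "proficient": (0.66, 0.6)}    # completeness fails: +0.5
print(aggregate_reward(scores, thresholds))  # -> 3.0
```

This makes the reward range explicit: a completion passing every check earns 6.0, while a JSON parse failure short-circuits to -5.0 before any level is scored.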
code/RL_model/unsloth_rl/testing_v2.py
ADDED
@@ -0,0 +1,138 @@
import json
from openai import OpenAI

class FactualityBenchmarker:
    def __init__(self, api_url="http://172.16.34.29:8004/v1", model="qwen3-32b-readctrl"):
        self.client = OpenAI(base_url=api_url, api_key="EMPTY")
        self.model = model

    def verify_claim(self, context, claim):
        """
        Asks the model to determine whether the context supports the claim.
        """
        prompt = f"""
CONTEXT:
{context}

CLAIM TO VERIFY:
{claim}

INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically follows from the context.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside information.

Output only one word: 'supported' or 'not_supported'.
"""

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,  # Zero temperature for consistency in benchmarks
                max_tokens=10
            )
            result = response.choices[0].message.content.strip().lower()
            return "supported" if "supported" in result and "not_supported" not in result else "not_supported"
        except Exception as e:
            print(f"Error: {e}")
            return "not_supported"

    def run_evaluation(self, test_cases):
        """
        Runs the benchmark over a list of test cases.
        Each test case: {"context": "...", "claims": [{"text": "...", "expected": "supported"/"not_supported"}]}
        """
        total_claims = 0
        correct_predictions = 0

        print(f"--- Starting Evaluation on {self.model} ---")

        for i, case in enumerate(test_cases):
            context = case["context"]
            print(f"\nTest Case {i+1}:")

            for claim_data in case["claims"]:
                claim_text = claim_data["text"]
                expected = claim_data["expected"]

                # Model prediction
                prediction = self.verify_claim(context, claim_text)

                is_correct = (prediction == expected)
                if is_correct:
                    correct_predictions += 1
                total_claims += 1

                status = "PASS" if is_correct else "FAIL"
                print(f"  [{status}] Claim: {claim_text[:60]}... (Expected: {expected}, Got: {prediction})")

        accuracy = (correct_predictions / total_claims) * 100 if total_claims > 0 else 0
        print("\n" + "=" * 30)
        print(f"FINAL ACCURACY: {accuracy:.2f}% ({correct_predictions}/{total_claims})")
        print("=" * 30)

# --- Define your test data here ---
test_data = [
    {
        "context": """CASE PRESENTATION:
A 64-year-old male with a 15-year history of Type 2 Diabetes Mellitus and stage 3 chronic kidney disease (CKD)
presented to the emergency department with acute shortness of breath and peripheral edema. On physical
examination, the patient was hypertensive (175/95 mmHg) and tachycardic (110 bpm). Lung auscultation revealed
bilateral crackles in the lower lobes, consistent with pulmonary congestion. Notable laboratory findings
included a Serum Creatinine of 2.8 mg/dL (baseline 1.9 mg/dL) and a Brain Natriuretic Peptide (BNP) of 1,250 pg/mL.

Crucially, the patient reported no history of tobacco use and denied any chest pain or radiating pain to the
left arm. An EKG showed sinus tachycardia but no ST-segment elevation or T-wave inversion. The medical team
initiated a regimen of intravenous furosemide (40mg bolus) and transitioned the patient from his home
medication (Metformin) to insulin glargine to manage blood glucose during the acute episode, citing concerns
over lactic acidosis risk given the acute kidney injury. After 48 hours, the patient's oxygen saturation
improved from 89% on room air to 95%, and his weight decreased by 3.2 kg due to successful diuresis.
The discharge summary noted that despite the respiratory distress, there were no signs of systemic infection
or fever during the entire 4-day hospital stay.""",
        "claims": [
            # 1. Literal extraction
            {"text": "The patient has had Type 2 Diabetes for 15 years.", "expected": "supported"},

            # 2. Medical paraphrasing (reading control)
            {"text": "The patient showed signs of fluid buildup in the lungs.", "expected": "supported"},  # 'bilateral crackles/congestion'

            # 3. Negative constraint (exclusionary fact)
            {"text": "The patient has a history of smoking.", "expected": "not_supported"},  # Text says 'no history of tobacco'

            # 4. Mathematical inference
            {"text": "The patient's Serum Creatinine increased by 0.9 mg/dL from his baseline.", "expected": "supported"},  # 2.8 - 1.9 = 0.9

            # 5. Logic: cause and effect
            {"text": "The doctors stopped Metformin because of the risk of lactic acidosis.", "expected": "supported"},

            # 6. Negative finding (testing 'silence')
            {"text": "The patient complained of pain moving down his left arm.", "expected": "not_supported"},  # Specifically denied

            # 7. Vital sign interpretation
            {"text": "The patient was experiencing high blood pressure and a fast heart rate upon arrival.", "expected": "supported"},  # 175/95 and 110 bpm

            # 8. Numerical recovery
            {"text": "The patient lost over 3 kilograms during the first two days of treatment.", "expected": "supported"},  # 3.2 kg

            # 9. Complex inference (EKG interpretation)
            {"text": "The EKG provided clear evidence of an active heart attack.", "expected": "not_supported"},  # Text says 'no ST-elevation'

            # 10. Systemic health status
            {"text": "The patient remained afebrile throughout the hospitalization.", "expected": "supported"}  # 'no fever' = afebrile
        ]
    },
    {
        "context": "The company reported a 15% increase in revenue, reaching $2 billion this quarter. However, net profit dropped due to high R&D costs.",
        "claims": [
            {"text": "Revenue reached $2 billion.", "expected": "supported"},
            {"text": "Net profit increased this quarter.", "expected": "not_supported"},
            {"text": "Spending on Research and Development impacted profits.", "expected": "supported"}
        ]
    }
]

if __name__ == "__main__":
    benchmarker = FactualityBenchmarker()
    benchmarker.run_evaluation(test_data)
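The accuracy bookkeeping in `run_evaluation` can be checked without a live vLLM endpoint by replaying the same claim loop with canned predictions. A self-contained sketch; the stub class and test data here are hypothetical, not part of testing_v2.py:

```python
# Hypothetical stub reproducing run_evaluation's pass/fail accounting:
# each claim counts once, and a prediction is correct when it matches
# the case's "expected" label.

class StubVerifier:
    """Returns canned predictions instead of calling a model."""
    def __init__(self, answers):
        self._answers = iter(answers)

    def verify_claim(self, context, claim):
        return next(self._answers)

cases = [{"context": "ctx", "claims": [
    {"text": "claim a", "expected": "supported"},
    {"text": "claim b", "expected": "not_supported"},
]}]

stub = StubVerifier(["supported", "supported"])  # second prediction is wrong
correct = sum(
    stub.verify_claim(c["context"], cl["text"]) == cl["expected"]
    for c in cases
    for cl in c["claims"]
)
total = sum(len(c["claims"]) for c in cases)
print(f"accuracy: {correct / total:.0%}")  # -> accuracy: 50%
```

The same replay trick also works for regression-testing prompt changes: rerun the stub with the previous model's saved verdicts and confirm the accuracy computation is unchanged.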
code/RL_model/verl/Search-R1/.gitignore
ADDED
@@ -0,0 +1,122 @@
**/*.pt
**/checkpoints
**/wget-log
**/_build/
**/*.ckpt
**/outputs
**/*.tar.gz
**/playground
**/wandb

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
dataset/*
tensorflow/my_graph/*
.idea/
# C extensions
*.so
data
sft/output/*
sft/data/*

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

image_outputs

checkpoints

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule


# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# vscode
.vscode

# Mac
.DS_Store

# output logs
tests/e2e/toy_examples/deepspeed/synchronous/output.txt

# vim
*.swp

# log*
log/

**logs
code/RL_model/verl/Search-R1/LICENSE
ADDED
@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
code/RL_model/verl/Search-R1/Notice.txt
ADDED
@@ -0,0 +1 @@
Copyright 2023-2024 Bytedance Ltd. and/or its affiliates
code/RL_model/verl/Search-R1/README.md
ADDED
@@ -0,0 +1,275 @@
# Search-R1: Train your LLMs to reason and call a search engine with reinforcement learning

<div align="center">
  <img src="https://raw.githubusercontent.com/PeterGriffinJin/Search-R1/main/public/logo.png" alt="logo" width="300"/>
</div>

<p align="center">
  <a href="https://arxiv.org/abs/2503.09516">
    <img src="https://img.shields.io/badge/Paper1-blue?style=for-the-badge" alt="Button1"/>
  </a>
  <a href="https://arxiv.org/abs/2505.15117">
    <img src="https://img.shields.io/badge/Paper2-green?style=for-the-badge" alt="Button2"/>
  </a>
  <a href="https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5">
    <img src="https://img.shields.io/badge/Resources-orange?style=for-the-badge" alt="Button3"/>
  </a>
  <a href="https://x.com/BowenJin13/status/1895544294473109889">
    <img src="https://img.shields.io/badge/Tweet-red?style=for-the-badge" alt="Button4"/>
  </a>
  <a href="https://wandb.ai/peterjin/Search-R1-v0.2">
    <img src="https://img.shields.io/badge/Logs-purple?style=for-the-badge" alt="Button5"/>
  </a>
</p>

<!-- <strong>Search-R1</strong> is a reinforcement learning framework for <em>training reasoning and searching (tool-call) interleaved LLMs</em>. -->
<!-- We built upon [veRL](https://github.com/volcengine/verl). -->
**Search-R1** is a reinforcement learning framework designed for training **reasoning-and-searching interleaved LLMs**—language models that learn to reason and make tool calls (e.g., to search engines) in a coordinated manner.

<!-- It can be seen as an extension of <strong>DeepSeek-R1(-Zero)</strong> with interleaved search engine calling and an opensource RL training-based solution for <strong>OpenAI DeepResearch</strong>. -->
Built upon [veRL](https://github.com/volcengine/verl), Search-R1 extends the ideas of **DeepSeek-R1(-Zero)** by incorporating interleaved search engine access and provides a fully open-source RL training pipeline. It serves as an alternative and open solution to **OpenAI DeepResearch**, enabling research and development in tool-augmented LLM reasoning.

<!-- Through RL (rule-based outcome reward), the 3B **base** LLM (both Qwen2.5-3b-base and Llama3.2-3b-base) develops reasoning and search engine calling abilities all on its own. -->

We support different RL methods (e.g., PPO, GRPO, REINFORCE), different LLMs (e.g., Llama3, Qwen2.5, etc.), and different search engines (e.g., local sparse/dense retrievers and online search engines).

Paper: [link1](https://arxiv.org/pdf/2503.09516), [link2](https://arxiv.org/abs/2505.15117); Model and data: [link](https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5); Twitter thread: [link](https://x.com/BowenJin13/status/1895544294473109889); Full experiment log: [prelim](https://wandb.ai/peterjin/Search-R1-open); [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2); [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3). Details about these logs and methods can be found [here](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/experiment_log.md).
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+

|
| 41 |
+
|
| 42 |
+
## News

- [2025.10] Search-R1 is featured in Thinking Machines Lab's first product, [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook)! Details: [document](https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/tinker_cookbook/recipes/tool_use/search).
- [2025.7] Search-R1 is supported by [SkyRL](https://github.com/NovaSky-AI/SkyRL)! Detailed instructions: [code](https://github.com/NovaSky-AI/SkyRL/tree/main/skyrl-train/examples/search), [document](https://novasky-ai.notion.site/skyrl-searchr1).
- [2025.6] Search-R1 is now integrated into the latest version of veRL and can take advantage of its most up-to-date features! Detailed instructions: [veRL](https://verl.readthedocs.io/en/latest/sglang_multiturn/search_tool_example.html), [English document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like.md), [Chinese document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like_ZH.md).
- [2025.5] The second [paper](https://arxiv.org/abs/2505.15117), a detailed empirical study, is published, with logs: [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3).
- [2025.4] We support [multi-node](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/multinode.md) training for 30B+ LLMs!
- [2025.4] We support [different search engines](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/retriever.md), including sparse local retrievers, dense local retrievers with ANN indexing, and online search engines!
- [2025.3] The first Search-R1 [paper](https://arxiv.org/pdf/2503.09516) is published, with logs: [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train), [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2).
- [2025.2] We open-sourced the Search-R1 codebase with [preliminary results](https://wandb.ai/peterjin/Search-R1-open).

## Links

- [Installation](#installation)
- [Quick start](#quick-start)
- [Preliminary results](#preliminary-results)
- [Inference](#inference)
- [Use your own dataset](#use-your-own-dataset)
- [Use your own search engine](#use-your-own-search-engine)
- [Features](#features)
- [Acknowledge](#acknowledge)
- [Citations](#citations)

## Installation

### Search-R1 environment
```bash
conda create -n searchr1 python=3.9
conda activate searchr1
# install torch [or skip this step and let vLLM install the correct version for you]
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# install vllm
pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 or 0.3.1

# verl
pip install -e .

# flash attention 2
pip3 install flash-attn --no-build-isolation
pip install wandb
```

### Retriever environment (optional)
If you would like to call a local retriever as the search engine, you can set up the environment as follows. (We recommend using a separate environment.)
```bash
conda create -n retriever python=3.10
conda activate retriever

# we recommend installing torch with conda for faiss-gpu
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers datasets pyserini

## install the GPU version of faiss to guarantee efficient RL rollout
conda install -c pytorch -c nvidia faiss-gpu=1.8.0

## API functionality
pip install uvicorn fastapi
```

## Quick start

Train a reasoning-plus-search LLM on the NQ dataset, with e5 as the retriever and Wikipedia as the corpus.

(1) Download the index and corpus.
```bash
save_path=/the/path/to/save
python scripts/download.py --save_path $save_path
cat $save_path/part_* > $save_path/e5_Flat.index
gzip -d $save_path/wiki-18.jsonl.gz
```

(2) Process the NQ dataset.
```bash
python scripts/data_process/nq_search.py
```

(3) Launch a local retrieval server.
```bash
conda activate retriever
bash retrieval_launch.sh
```

(4) Run RL training (PPO) with Llama-3.2-3b-base.
```bash
conda activate searchr1
bash train_ppo.sh
```

## Preliminary results

(1) The base model (Llama-3.2-3b-base) learns to call the search engine and obtains improved performance.



(2) The base model (Qwen2.5-7b-base) learns to conduct multi-turn search engine calling and reasoning with RL.


## Inference
#### You can play with the trained Search-R1 model with your own question.
(1) Launch a local retrieval server.
```bash
conda activate retriever
bash retrieval_launch.sh
```

(2) Run inference.
```bash
conda activate searchr1
python infer.py
```
You can modify the ```question``` on line 7 to something you're interested in.
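The trained model wraps its final answer in `<answer>...</answer>` tags, following the prompt template in `infer.py`. If you want to post-process a transcript programmatically, a small helper along these lines can pull the answer out; `extract_answer` is an illustrative name, not a function shipped with this repository:

```python
import re


def extract_answer(text: str):
    """Return the content of the last <answer>...</answer> span, or None.

    The tag format follows the Search-R1 prompt template: reasoning goes in
    <think>, queries in <search>, and the final answer in <answer>.
    """
    matches = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return matches[-1].strip() if matches else None


print(extract_answer("<think>recall capital</think> <answer> Beijing </answer>"))  # -> Beijing
```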

## Use your own dataset

### QA data
Each question-answer sample should be a dictionary with the following fields:

```
data = {
    "data_source": data_source,
    "prompt": [{
        "role": "user",
        "content": question,
    }],
    "ability": "fact-reasoning",
    "reward_model": {
        "style": "rule",
        "ground_truth": solution
    },
    "extra_info": {
        "split": split,
        "index": idx,
    }
}
```

You can refer to ```scripts/data_process/nq_search.py``` for a concrete data processing example.
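The `"style": "rule"` entry above indicates a rule-based outcome reward computed against `ground_truth`. As a rough illustration of such a reward, here is a minimal exact-match scorer; the normalization details are an assumption made for this sketch, not the repository's exact implementation:

```python
import re
import string


def normalize(s: str) -> str:
    # Lowercase, drop punctuation, remove articles, collapse whitespace
    # (a common exact-match normalization for open-domain QA).
    s = "".join(ch for ch in s.lower() if ch not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def em_reward(prediction: str, ground_truth: str) -> float:
    # Rule-based outcome reward: 1.0 on exact (normalized) match, else 0.0.
    return 1.0 if normalize(prediction) == normalize(ground_truth) else 0.0
```

In training, `prediction` would be the text the policy places inside its `<answer>...</answer>` tags.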

### Corpora

It is recommended to store your corpus as a JSONL file, where each line (a dictionary with an "id" key and a "contents" key) corresponds to one passage. You can refer to ```example/corpus.jsonl``` for an example.

The "id" key is the passage id, while the "contents" key holds the passage content ('"' + title + '"\n' + text).
For example:
```
{"id": "0", "contents": "Evan Morris Evan L. Morris (January 26, 1977 \u2013 July 9, 2015) was a lobbyist for Genentech and its parent corporation Roche in Washington."}
...
{"id": "100", "contents": "Three years later, when the United States Exploring Expedition to little-known portions of the globe was organised under Charles Wilkes, Hale was recommended, while yet an undergraduate."}
...
```
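As a minimal sketch, a corpus file following the stated convention ('"' + title + '"\n' + text) can be produced like this; the output path and the sample passage are hypothetical:

```python
import json

# Hypothetical (title, text) passages; a real corpus would be much larger.
passages = [
    ("Evan Morris", "Evan L. Morris was a lobbyist for Genentech and its parent corporation Roche."),
]

with open("corpus.jsonl", "w", encoding="utf-8") as f:
    for idx, (title, text) in enumerate(passages):
        # "contents" follows the convention above: quoted title, newline, passage text.
        record = {"id": str(idx), "contents": f'"{title}"\n{text}'}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```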

**Index your corpora (optional).**
If you would like to use a local retriever as the search engine, you can index your own corpus by:
```
bash search_r1/search/build_index.sh
```
You can change ```retriever_name``` and ```retriever_model``` to the off-the-shelf retriever of your choice.

## Use your own search engine

Our codebase supports local sparse retrievers (e.g., BM25), local dense retrievers (both flat indexing on GPUs and ANN indexing on CPUs), and online search engines (e.g., Google, Bing, etc.). More details can be found [here](https://github.com/PeterGriffinJin/Search-R1/tree/main/docs/retriever.md).

The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline.

The LLM can call the search engine through the search API (e.g., "http://127.0.0.1:8000/retrieve").

You can refer to ```search_r1/search/retriever_server.py``` for an example of launching a local retriever server.
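A minimal client sketch for that endpoint follows; the payload and response shape are taken from `infer.py` in this repository, while `retrieve` and `format_passages` are illustrative names:

```python
import json
import urllib.request

RETRIEVE_URL = "http://127.0.0.1:8000/retrieve"  # default local retriever endpoint


def format_passages(retrieval_result):
    """Render one query's results as numbered docs; each passage's first
    line is its title and the rest is the body (the corpus convention above)."""
    lines = []
    for idx, doc_item in enumerate(retrieval_result):
        content = doc_item["document"]["contents"]
        title, _, text = content.partition("\n")
        lines.append(f"Doc {idx + 1} (Title: {title}) {text}")
    return "\n".join(lines)


def retrieve(queries, topk=3):
    """POST to the local retriever; requires `bash retrieval_launch.sh` running."""
    payload = json.dumps({"queries": queries, "topk": topk, "return_scores": True}).encode()
    req = urllib.request.Request(RETRIEVE_URL, data=payload, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["result"]  # one result list per query


# With the retrieval server running:
#   results = retrieve(["capital of France"])
#   print(format_passages(results[0]))
```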

## Features
- Support local sparse retrievers (e.g., BM25). ✔️
- Support local dense retrievers (both flat indexing and ANN indexing). ✔️
- Support Google / Bing / Brave search APIs and others. ✔️
- Support off-the-shelf neural rerankers. ✔️
- Support different RL methods (e.g., PPO, GRPO, REINFORCE). ✔️
- Support different LLMs (e.g., Llama3, Qwen2.5, etc.). ✔️

## Acknowledge

The concept of Search-R1 is inspired by [Deepseek-R1](https://github.com/deepseek-ai/DeepSeek-R1) and [TinyZero](https://github.com/Jiayi-Pan/TinyZero/tree/main).
Its implementation is built upon [veRL](https://github.com/volcengine/verl) and [RAGEN](https://github.com/ZihanWang314/RAGEN/tree/main).
We sincerely appreciate these teams' contributions to open-source research and development.

## Awesome work powered or inspired by Search-R1

- [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling Deep Research via Reinforcement Learning in Real-world Environments. [![[code]](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher)](https://github.com/GAIR-NLP/DeepResearcher)
- [Multimodal-Search-R1](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1): Incentivizing LMMs to Search. [![[code]](https://img.shields.io/github/stars/EvolvingLMMs-Lab/multimodal-search-r1)](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1)
- [OTC](https://arxiv.org/pdf/2504.14870): Optimal Tool Calls via Reinforcement Learning.
- [ZeroSearch](https://github.com/Alibaba-NLP/ZeroSearch): Incentivize the Search Capability of LLMs without Searching. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/ZeroSearch)](https://github.com/Alibaba-NLP/ZeroSearch)
- [IKEA](https://github.com/hzy312/knowledge-r1): Reinforced Internal-External Knowledge Synergistic Reasoning for Efficient Adaptive Search Agent. [![[code]](https://img.shields.io/github/stars/hzy312/knowledge-r1)](https://github.com/hzy312/knowledge-r1)
- [Scent of Knowledge](https://arxiv.org/abs/2505.09316): Optimizing Search-Enhanced Reasoning with Information Foraging.
- [AutoRefine](https://www.arxiv.org/pdf/2505.11277): Search and Refine During Think. [![[code]](https://img.shields.io/github/stars/syr-cn/AutoRefine)](https://github.com/syr-cn/AutoRefine)
- [O^2-Searcher](https://arxiv.org/pdf/2505.16582): A Searching-based Agent Model for Open-Domain Open-Ended Question Answering. [![[code]](https://img.shields.io/github/stars/Acade-Mate/O2-Searcher)](https://github.com/Acade-Mate/O2-Searcher)
- [MaskSearch](https://arxiv.org/pdf/2505.20285): A Universal Pre-Training Framework to Enhance Agentic Search Capability. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/MaskSearch)](https://github.com/Alibaba-NLP/MaskSearch)
- [VRAG-RL](https://arxiv.org/abs/2505.22019): Vision-Perception-Based RAG for Visually Rich Information Understanding. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/VRAG)](https://github.com/Alibaba-NLP/VRAG)
- [R1-Code-Interpreter](https://arxiv.org/abs/2505.21668): Training LLMs to Reason with Code via SFT and RL. [![[code]](https://img.shields.io/github/stars/yongchao98/R1-Code-Interpreter)](https://github.com/yongchao98/R1-Code-Interpreter)
- [R-Search](https://arxiv.org/abs/2506.04185): Empowering LLM Reasoning with Search via Multi-Reward Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/QingFei1/R-Search)](https://github.com/QingFei1/R-Search)
- [StepSearch](https://arxiv.org/pdf/2505.15107): Igniting LLMs' Search Ability via Step-Wise Proximal Policy Optimization. [![[code]](https://img.shields.io/github/stars/Zillwang/StepSearch)](https://github.com/Zillwang/StepSearch)
- [SimpleTIR](https://simpletir.notion.site/report): Stable End-to-End Reinforcement Learning for Multi-Turn Tool-Integrated Reasoning. [![[code]](https://img.shields.io/github/stars/ltzheng/SimpleTIR)](https://github.com/ltzheng/SimpleTIR)
- [Router-R1](https://arxiv.org/pdf/2506.09033): Teaching LLMs Multi-Round Routing and Aggregation via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/ulab-uiuc/Router-R1)](https://github.com/ulab-uiuc/Router-R1)
- [SkyRL](https://skyrl.readthedocs.io/en/latest/): A Modular Full-stack RL Library for LLMs. [![[code]](https://img.shields.io/github/stars/NovaSky-AI/SkyRL)](https://github.com/NovaSky-AI/SkyRL)
- [ASearcher](https://arxiv.org/abs/2508.07976): Large-Scale RL for Search Agents. [![[code]](https://img.shields.io/github/stars/inclusionAI/ASearcher)](https://github.com/inclusionAI/ASearcher)
- [ParallelSearch](https://www.arxiv.org/abs/2508.09303): Decompose Query and Search Sub-queries in Parallel with RL. [![[code]](https://img.shields.io/github/stars/Tree-Shu-Zhao/ParallelSearch)](https://github.com/Tree-Shu-Zhao/ParallelSearch)
- [AutoTIR](https://arxiv.org/pdf/2507.21836): Autonomous Tools Integrated Reasoning via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/weiyifan1023/AutoTIR)](https://github.com/weiyifan1023/AutoTIR)
- [verl-tool](https://arxiv.org/pdf/2509.01055): A version of verl to support diverse tool use. [![[code]](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)](https://github.com/TIGER-AI-Lab/verl-tool)
- [Tree-GRPO](https://arxiv.org/abs/2509.21240): Tree Search for LLM Agent Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/AMAP-ML/Tree-GRPO)](https://github.com/AMAP-ML/Tree-GRPO)
- [EviNote-RAG](https://arxiv.org/abs/2509.00877): Enhancing RAG Models via Answer-Supportive Evidence Notes. [![[code]](https://img.shields.io/github/stars/Da1yuqin/EviNoteRAG)](https://github.com/Da1yuqin/EviNoteRAG)
- [GlobalRAG](https://arxiv.org/pdf/2510.20548v1): Enhancing Global Reasoning in Multi-hop Question Answering via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CarnegieBin/GlobalRAG)](https://github.com/CarnegieBin/GlobalRAG)

## Citations

```bibtex
@article{jin2025search,
  title={Search-R1: Training LLMs to Reason and Leverage Search Engines with Reinforcement Learning},
  author={Jin, Bowen and Zeng, Hansi and Yue, Zhenrui and Yoon, Jinsung and Arik, Sercan and Wang, Dong and Zamani, Hamed and Han, Jiawei},
  journal={arXiv preprint arXiv:2503.09516},
  year={2025}
}
```

```bibtex
@article{jin2025empirical,
  title={An Empirical Study on Reinforcement Learning for Reasoning-Search Interleaved LLM Agents},
  author={Jin, Bowen and Yoon, Jinsung and Kargupta, Priyanka and Arik, Sercan O and Han, Jiawei},
  journal={arXiv preprint arXiv:2505.15117},
  year={2025}
}
```
code/RL_model/verl/Search-R1/VERL_README.md
ADDED
@@ -0,0 +1,103 @@
<h1 style="text-align: center;">veRL: Volcano Engine Reinforcement Learning for LLM</h1>

veRL is a flexible, efficient and production-ready RL training framework designed for large language models (LLMs).

veRL is the open-source version of the **[HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)** paper.

veRL is flexible and easy to use with:

- **Easy extension of diverse RL algorithms**: The hybrid programming model combines the strengths of the single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex post-training dataflows, allowing users to build RL dataflows in a few lines of code.

- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks such as PyTorch FSDP, Megatron-LM and vLLM. Moreover, users can easily extend to other LLM training and inference frameworks.

- **Flexible device mapping**: Supports various placements of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.

- **Ready integration with popular HuggingFace models**.


veRL is fast with:

- **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, veRL achieves high generation and training throughput.

- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between the training and generation phases.

<p align="center">
| <a href="https://verl.readthedocs.io/en/latest/index.html"><b>Documentation</b></a> | <a href="https://arxiv.org/abs/2409.19256v2"><b>Paper</b></a> | <a href="https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA"><b>Slack</b></a> | <a href="https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG"><b>Wechat</b></a> |
</p>

## News

- [2024/12] The team presented <a href="https://neurips.cc/Expo/Conferences/2024/workshop/100677">Post-training LLMs: From Algorithms to Infrastructure</a> at NeurIPS 2024. [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips) and [video](https://neurips.cc/Expo/Conferences/2024/workshop/100677) available.
- [2024/10] veRL was presented at Ray Summit. [YouTube video](https://www.youtube.com/watch?v=MrhMcXkXvJU&list=PLzTswPQNepXntmT8jr9WaNfqQ60QwW7-U&index=37) available.
- [2024/08] HybridFlow (veRL) was accepted to EuroSys 2025.

## Key Features

- **FSDP** and **Megatron-LM** for training.
- **vLLM** and **TGI** for rollout generation; **SGLang** support coming soon.
- HuggingFace model support.
- Supervised fine-tuning.
- Reward model training.
- Reinforcement learning from human feedback with PPO.
- flash-attention integration and sequence packing.
- Scales up to 70B models and hundreds of GPUs.
- Experiment tracking with wandb and mlflow.


## Getting Started

Check out this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer/verl_getting_started.ipynb) to get started with PPO training on a single 24GB L4 GPU (**FREE** GPU quota provided by [Lightning Studio](https://lightning.ai/hlin-verl/studios/verl-getting-started))!

**Quickstart:**
- [Installation](https://verl.readthedocs.io/en/latest/start/install.html)
- [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)

**Running a PPO example step-by-step:**
- Data and Reward Preparation
  - [Prepare Data (Parquet) for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
  - [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)
- Understanding the PPO Example
  - [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)
  - [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)
  - [Run GSM8K Example](https://verl.readthedocs.io/en/latest/examples/gsm8k_example.html)

**Reproducible algorithm baselines:**
- [PPO](https://verl.readthedocs.io/en/latest/experiment/ppo.html)

**For code explanation and advanced usage (extension):**
- PPO Trainer and Workers
  - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)
  - [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html)
  - [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/index.html)
- Advanced Usage and Extension
  - [Ray API Design Tutorial](https://verl.readthedocs.io/en/latest/advance/placement.html)
  - [Extend to other RL(HF) algorithms](https://verl.readthedocs.io/en/latest/advance/dpo_extension.html)
  - [Add models with the FSDP backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html)
  - [Add models with the Megatron-LM backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html)


## Citation and acknowledgement

If you find the project helpful, please cite:
- [HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)
- [A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization](https://i.cs.hku.hk/~cwu/papers/gmsheng-NL2Code24.pdf)

```tex
@article{sheng2024hybridflow,
  title   = {HybridFlow: A Flexible and Efficient RLHF Framework},
  author  = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},
  year    = {2024},
  journal = {arXiv preprint arXiv: 2409.19256}
}
```

verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The project is adopted and supported by Anyscale, Bytedance, LMSys.org, Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, and the University of Hong Kong.

## Publications Using veRL
- [Enhancing Multi-Step Reasoning Abilities of Language Models through Direct Q-Function Optimization](https://arxiv.org/abs/2410.09302)
- [Flaming-hot Initiation with Regular Execution Sampling for Large Language Models](https://arxiv.org/abs/2410.21236)
- [Process Reinforcement Through Implicit Rewards](https://github.com/PRIME-RL/PRIME/)

We are HIRING! Send us an [email](mailto:haibin.lin@bytedance.com) if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment.
code/RL_model/verl/Search-R1/infer.py
ADDED
@@ -0,0 +1,128 @@
| 1 |
+
import transformers
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
import requests
|
| 6 |
+
|
| 7 |
+
question = "Mike Barnett negotiated many contracts including which player that went on to become general manager of CSKA Moscow of the Kontinental Hockey League?"
|
| 8 |
+
|
| 9 |
+
# Model ID and device setup
|
| 10 |
+
model_id = "PeterJinGo/SearchR1-nq_hotpotqa_train-qwen2.5-7b-em-ppo"
|
| 11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 12 |
+
|
| 13 |
+
question = question.strip()
|
| 14 |
+
if question[-1] != '?':
|
| 15 |
+
question += '?'
|
| 16 |
+
curr_eos = [151645, 151643] # for Qwen2.5 series models
|
| 17 |
+
curr_search_template = '\n\n{output_text}<information>{search_results}</information>\n\n'
|
| 18 |
+
|
| 19 |
+
# Prepare the message
|
| 20 |
+
prompt = f"""Answer the given question. \
|
| 21 |
+
You must conduct reasoning inside <think> and </think> first every time you get new information. \
|
| 22 |
+
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
|
| 23 |
+
You can search as many times as your want. \
|
| 24 |
+
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
|
| 25 |
+
|
| 26 |
+
# Initialize the tokenizer and model
|
| 27 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
|
| 28 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
|
| 29 |
+
|
| 30 |
+
# Define the custom stopping criterion
|
| 31 |
+
class StopOnSequence(transformers.StoppingCriteria):
|
| 32 |
+
def __init__(self, target_sequences, tokenizer):
|
| 33 |
+
# Encode the string so we have the exact token-IDs pattern
|
| 34 |
+
self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
|
| 35 |
+
        self.target_lengths = [len(target_id) for target_id in self.target_ids]
        self._tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Make sure the target IDs are on the same device
        targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]

        if input_ids.shape[1] < min(self.target_lengths):
            return False

        # Compare the tail of input_ids with our target_ids
        for i, target in enumerate(targets):
            if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
                return True

        return False


def get_query(text):
    import re
    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
    matches = pattern.findall(text)
    if matches:
        return matches[-1]
    else:
        return None


def search(query: str):
    payload = {
        "queries": [query],
        "topk": 3,
        "return_scores": True
    }
    results = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()['result']

    def _passages2string(retrieval_result):
        format_reference = ''
        for idx, doc_item in enumerate(retrieval_result):
            content = doc_item['document']['contents']
            title = content.split("\n")[0]
            text = "\n".join(content.split("\n")[1:])
            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
        return format_reference

    return _passages2string(results[0])


# Initialize the stopping criteria
target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])

cnt = 0

if tokenizer.chat_template:
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)

print('\n\n################# [Start Reasoning + Searching] ##################\n\n')
print(prompt)

# Encode the chat-formatted prompt and move it to the correct device
while True:
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones_like(input_ids)

    # Generate text with the stopping criteria
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7
    )

    if outputs[0][-1].item() in curr_eos:
        generated_tokens = outputs[0][input_ids.shape[1]:]
        output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print(output_text)
        break

    generated_tokens = outputs[0][input_ids.shape[1]:]
    output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    tmp_query = get_query(tokenizer.decode(outputs[0], skip_special_tokens=True))
    if tmp_query:
        # print(f'searching "{tmp_query}"...')
        search_results = search(tmp_query)
    else:
        search_results = ''

    search_text = curr_search_template.format(output_text=output_text, search_results=search_results)
    prompt += search_text
    cnt += 1
    print(search_text)
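For reference, the `get_query` helper above keeps only the *last* `<search>…</search>` span (via `findall` plus `matches[-1]`), so in a multi-turn transcript the most recent query wins. A minimal standalone sketch of that behavior (the example strings are made up for illustration):

```python
import re

def get_query(text):
    # Same regex as in the script: non-greedy match, DOTALL so the
    # query may span newlines; return only the most recent block.
    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
    matches = pattern.findall(text)
    return matches[-1] if matches else None

transcript = "<search>first query</search> ... <search>second query</search>"
print(get_query(transcript))   # second query
print(get_query("no tags"))    # None
```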
code/RL_model/verl/Search-R1/llm_guard_3B_10k_v2.log
ADDED
@@ -0,0 +1,180 @@
/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
2026-02-01 20:43:15,317 INFO worker.py:2014 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8301
/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
(pid=1646422) /home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
(pid=1646422)   import pynvml  # type: ignore[import]
(main_task pid=1646422) {'actor_rollout_ref': {'actor': {'clip_ratio': 0.2,
(main_task pid=1646422) 'entropy_coeff': 0.001,
(main_task pid=1646422) 'fsdp_config': {'fsdp_size': -1,
(main_task pid=1646422) 'grad_offload': False,
(main_task pid=1646422) 'optimizer_offload': True,
(main_task pid=1646422) 'param_offload': True,
(main_task pid=1646422) 'wrap_policy': {'min_num_params': 0}},
(main_task pid=1646422) 'grad_clip': 1.0,
(main_task pid=1646422) 'kl_loss_coef': 0.001,
(main_task pid=1646422) 'kl_loss_type': 'low_var_kl',
(main_task pid=1646422) 'optim': {'lr': 1e-06,
(main_task pid=1646422) 'lr_warmup_steps_ratio': 0.0,
(main_task pid=1646422) 'min_lr_ratio': None,
(main_task pid=1646422) 'total_training_steps': -1,
(main_task pid=1646422) 'warmup_style': 'constant'},
(main_task pid=1646422) 'ppo_epochs': 1,
(main_task pid=1646422) 'ppo_max_token_len_per_gpu': 16384,
(main_task pid=1646422) 'ppo_micro_batch_size': 64,
(main_task pid=1646422) 'ppo_micro_batch_size_per_gpu': 16,
(main_task pid=1646422) 'ppo_mini_batch_size': 64,
(main_task pid=1646422) 'shuffle': False,
(main_task pid=1646422) 'state_masking': False,
(main_task pid=1646422) 'strategy': 'fsdp',
(main_task pid=1646422) 'ulysses_sequence_parallel_size': 1,
(main_task pid=1646422) 'use_dynamic_bsz': False,
(main_task pid=1646422) 'use_kl_loss': False},
(main_task pid=1646422) 'hybrid_engine': True,
(main_task pid=1646422) 'model': {'enable_gradient_checkpointing': True,
(main_task pid=1646422) 'external_lib': None,
(main_task pid=1646422) 'override_config': {},
(main_task pid=1646422) 'path': 'Qwen/Qwen3-4B-Instruct-2507',
(main_task pid=1646422) 'use_remove_padding': False},
(main_task pid=1646422) 'ref': {'fsdp_config': {'fsdp_size': -1,
(main_task pid=1646422) 'param_offload': True,
(main_task pid=1646422) 'wrap_policy': {'min_num_params': 0}},
(main_task pid=1646422) 'log_prob_max_token_len_per_gpu': 16384,
(main_task pid=1646422) 'log_prob_micro_batch_size': 64,
(main_task pid=1646422) 'log_prob_use_dynamic_bsz': False,
(main_task pid=1646422) 'ulysses_sequence_parallel_size': 1},
(main_task pid=1646422) 'rollout': {'do_sample': True,
(main_task pid=1646422) 'dtype': 'bfloat16',
(main_task pid=1646422) 'enforce_eager': True,
(main_task pid=1646422) 'free_cache_engine': True,
(main_task pid=1646422) 'gpu_memory_utilization': 0.4,
(main_task pid=1646422) 'ignore_eos': False,
(main_task pid=1646422) 'load_format': 'dummy_dtensor',
(main_task pid=1646422) 'log_prob_max_token_len_per_gpu': 16384,
(main_task pid=1646422) 'log_prob_micro_batch_size': 64,
(main_task pid=1646422) 'log_prob_use_dynamic_bsz': False,
(main_task pid=1646422) 'max_num_batched_tokens': 8192,
(main_task pid=1646422) 'max_num_seqs': 1024,
(main_task pid=1646422) 'n': 1,
(main_task pid=1646422) 'n_agent': 1,
(main_task pid=1646422) 'name': 'vllm',
(main_task pid=1646422) 'prompt_length': 4096,
(main_task pid=1646422) 'response_length': 1024,
(main_task pid=1646422) 'temperature': 1.0,
(main_task pid=1646422) 'tensor_model_parallel_size': 1,
(main_task pid=1646422) 'top_k': -1,
(main_task pid=1646422) 'top_p': 0.95}},
(main_task pid=1646422) 'algorithm': {'adv_estimator': 'grpo',
(main_task pid=1646422) 'gamma': 1.0,
(main_task pid=1646422) 'kl_ctrl': {'kl_coef': 0.001, 'type': 'fixed'},
(main_task pid=1646422) 'kl_penalty': 'kl',
(main_task pid=1646422) 'lam': 1.0,
(main_task pid=1646422) 'no_think_rl': False,
(main_task pid=1646422) 'state_masking': {'end_state_marker': '</information>',
(main_task pid=1646422) 'start_state_marker': '<information>'}},
(main_task pid=1646422) 'critic': {'cliprange_value': 0.5,
(main_task pid=1646422) 'forward_max_token_len_per_gpu': 32768,
(main_task pid=1646422) 'forward_micro_batch_size': 64,
(main_task pid=1646422) 'grad_clip': 1.0,
(main_task pid=1646422) 'model': {'enable_gradient_checkpointing': False,
(main_task pid=1646422) 'external_lib': None,
(main_task pid=1646422) 'fsdp_config': {'fsdp_size': -1,
(main_task pid=1646422) 'grad_offload': False,
(main_task pid=1646422) 'optimizer_offload': False,
(main_task pid=1646422) 'param_offload': False,
(main_task pid=1646422) 'wrap_policy': {'min_num_params': 0}},
(main_task pid=1646422) 'override_config': {},
(main_task pid=1646422) 'path': '~/models/deepseek-llm-7b-chat',
(main_task pid=1646422) 'tokenizer_path': 'Qwen/Qwen3-4B-Instruct-2507',
(main_task pid=1646422) 'use_remove_padding': False},
(main_task pid=1646422) 'optim': {'lr': 1e-05,
(main_task pid=1646422) 'lr_warmup_steps_ratio': 0.0,
(main_task pid=1646422) 'min_lr_ratio': None,
(main_task pid=1646422) 'total_training_steps': -1,
(main_task pid=1646422) 'warmup_style': 'constant'},
(main_task pid=1646422) 'ppo_epochs': 1,
(main_task pid=1646422) 'ppo_max_token_len_per_gpu': 32768,
(main_task pid=1646422) 'ppo_micro_batch_size': 64,
(main_task pid=1646422) 'ppo_mini_batch_size': 64,
(main_task pid=1646422) 'shuffle': False,
(main_task pid=1646422) 'strategy': 'fsdp',
(main_task pid=1646422) 'ulysses_sequence_parallel_size': 1,
(main_task pid=1646422) 'use_dynamic_bsz': False},
(main_task pid=1646422) 'data': {'max_obs_length': 512,
(main_task pid=1646422) 'max_prompt_length': 4096,
(main_task pid=1646422) 'max_response_length': 1024,
(main_task pid=1646422) 'max_start_length': 256,
(main_task pid=1646422) 'prompt_key': 'prompt',
(main_task pid=1646422) 'return_raw_chat': False,
(main_task pid=1646422) 'return_raw_input_ids': False,
(main_task pid=1646422) 'shuffle_train_dataloader': True,
(main_task pid=1646422) 'tokenizer': None,
(main_task pid=1646422) 'train_batch_size': 128,
(main_task pid=1646422) 'train_data_num': None,
(main_task pid=1646422) 'train_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet',
(main_task pid=1646422) 'val_batch_size': 64,
(main_task pid=1646422) 'val_data_num': None,
(main_task pid=1646422) 'val_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet'},
(main_task pid=1646422) 'do_search': False,
(main_task pid=1646422) 'max_turns': 1,
(main_task pid=1646422) 'retriever': {'topk': 3, 'url': 'http://127.0.0.1:8000/retrieve'},
(main_task pid=1646422) 'reward_model': {'enable': False,
(main_task pid=1646422) 'final_format_score': 0,
(main_task pid=1646422) 'forward_max_token_len_per_gpu': 32768,
(main_task pid=1646422) 'max_length': None,
(main_task pid=1646422) 'micro_batch_size': 64,
(main_task pid=1646422) 'model': {'external_lib': None,
(main_task pid=1646422) 'fsdp_config': {'min_num_params': 0,
(main_task pid=1646422) 'param_offload': False},
(main_task pid=1646422) 'input_tokenizer': 'Qwen/Qwen3-4B-Instruct-2507',
(main_task pid=1646422) 'path': '~/models/FsfairX-LLaMA3-RM-v0.1',
(main_task pid=1646422) 'use_remove_padding': False},
(main_task pid=1646422) 'retrieval_score': 0,
(main_task pid=1646422) 'strategy': 'fsdp',
(main_task pid=1646422) 'structure_format_score': 0,
(main_task pid=1646422) 'ulysses_sequence_parallel_size': 1,
(main_task pid=1646422) 'use_dynamic_bsz': False},
(main_task pid=1646422) 'trainer': {'critic_warmup': 0,
(main_task pid=1646422) 'default_hdfs_dir': '~/experiments/gsm8k/ppo/llm_guard_3B_10k_v2',
(main_task pid=1646422) 'default_local_dir': 'verl_checkpoints/llm_guard_3B_10k_v2',
(main_task pid=1646422) 'experiment_name': 'llm_guard_3B_10k_v2',
(main_task pid=1646422) 'logger': ['wandb'],
(main_task pid=1646422) 'n_gpus_per_node': 2,
(main_task pid=1646422) 'nnodes': 1,
(main_task pid=1646422) 'project_name': '',
(main_task pid=1646422) 'save_freq': 100,
(main_task pid=1646422) 'test_freq': 50,
(main_task pid=1646422) 'total_epochs': 15,
(main_task pid=1646422) 'total_training_steps': 1005}}
(main_task pid=1646422) W0201 20:43:46.380000 1646422 /data/home_beta/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/torch/utils/cpp_extension.py:117] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Error executing job with overrides: ['data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet', 'data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet', 'data.train_batch_size=128', 'data.val_batch_size=64', 'data.max_prompt_length=4096', 'data.max_response_length=1024', 'data.shuffle_train_dataloader=True', 'algorithm.adv_estimator=grpo', 'actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507', 'actor_rollout_ref.model.enable_gradient_checkpointing=true', 'actor_rollout_ref.model.use_remove_padding=False', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.actor.ppo_mini_batch_size=64', '+actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16', 'actor_rollout_ref.actor.fsdp_config.param_offload=true', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=true', 'actor_rollout_ref.rollout.log_prob_micro_batch_size=64', 'actor_rollout_ref.rollout.tensor_model_parallel_size=1', 'actor_rollout_ref.rollout.name=vllm', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.4', 'actor_rollout_ref.ref.log_prob_micro_batch_size=64', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'actor_rollout_ref.actor.kl_loss_coef=0.001', 'trainer.logger=[wandb]', 'trainer.n_gpus_per_node=2', 'trainer.nnodes=1', 'trainer.save_freq=100', 'trainer.test_freq=50', 'trainer.project_name=', 'trainer.experiment_name=llm_guard_3B_10k_v2', 'trainer.total_epochs=15', 'trainer.total_training_steps=1005', 'trainer.default_local_dir=verl_checkpoints/llm_guard_3B_10k_v2', 'do_search=false', 'max_turns=1']
Traceback (most recent call last):
  File "/data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/main_ppo.py", line 110, in main
    ray.get(main_task.remote(config))
  File "/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/ray/_private/worker.py", line 2972, in get
    values, debugger_breakpoint = worker.get_objects(
                                  ^^^^^^^^^^^^^^^^^^^
  File "/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/ray/_private/worker.py", line 1031, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ImportError): ray::main_task() (pid=1646422, ip=172.16.34.29)
  File "/data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/trainer/main_ppo.py", line 136, in main_task
    from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
  File "/data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/workers/fsdp_workers.py", line 39, in <module>
    from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
  File "/data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/workers/sharding_manager/__init__.py", line 23, in <module>
    from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
  File "/data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/workers/sharding_manager/megatron_vllm.py", line 230, in <module>
    from verl.third_party.vllm import parallel_state as vllm_ps
  File "/data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/verl/third_party/vllm/__init__.py", line 52, in <module>
    from vllm import LLM, LLMEngine, parallel_state
ImportError: cannot import name 'parallel_state' from 'vllm' (/home/mshahidul/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/__init__.py)

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
[W201 20:43:50.186179538 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
code/RL_model/verl/Search-R1/pyproject.toml
ADDED
@@ -0,0 +1,78 @@
# -------------------------------
# build-system
# -------------------------------
[build-system]
requires = [
    "setuptools>=61.0",
    "wheel"
]
build-backend = "setuptools.build_meta"

# -------------------------------
# project (PEP 621 metadata)
# -------------------------------
[project]
name = "verl"
# We'll mark the version as "dynamic" because it's read from the file "verl/version/version"
# (PEP 621 calls this a "dynamic version").
# The actual version is specified in the [tool.setuptools.dynamic] section below.
dynamic = ["version"]

description = "veRL: Volcano Engine Reinforcement Learning for LLM"
license = {file = "LICENSE"}  # or "Apache-2.0", if you prefer an SPDX identifier
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.8"

authors = [
    { name = "Bytedance - Seed - MLSys", email = "zhangchi.usc1992@bytedance.com" },
    { name = "Bytedance - Seed - MLSys", email = "gmsheng@connect.hku.hk" },
]

# Dependencies corresponding to install_requires in setup.py
dependencies = [
    "accelerate",
    "codetiming",
    "datasets",
    "dill",
    "hydra-core",
    "numpy",
    "pybind11",
    "ray",
    "tensordict",
    "transformers<4.48",
    "vllm<=0.6.3",
]

# Optional dependencies (extras_require in setup.py)
[project.optional-dependencies]
test = [
    "pytest", "yapf"
]

# URLs
[project.urls]
Homepage = "https://github.com/volcengine/verl"

# -------------------------------
# tool.setuptools - Additional config
# -------------------------------
[tool.setuptools]
# True means `setuptools` will attempt to include all relevant files in package_data automatically.
# This corresponds to `include_package_data=True` in setup.py.
include-package-data = true

# We read the version from a file at 'verl/version/version'
[tool.setuptools.dynamic]
version = {file = "verl/version/version"}

# If you need to mimic `package_dir={'': '.'}`:
[tool.setuptools.package-dir]
"" = "."

# If you need to include specific non-Python data (like YAML files or the version file):
# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']}
[tool.setuptools.package-data]
verl = [
    "version/*",
    "trainer/config/*.yaml"
]
code/RL_model/verl/Search-R1/requirements.txt
ADDED
@@ -0,0 +1,16 @@
accelerate
codetiming
datasets
dill
flash-attn
hydra-core
numpy
pandas
pybind11
ray
tensordict<0.6
transformers<4.48
vllm<=0.6.3
wandb
IPython
matplotlib
code/RL_model/verl/Search-R1/retrieval_launch.sh
ADDED
@@ -0,0 +1,13 @@

file_path=/the/path/you/save/corpus
index_file=$file_path/e5_Flat.index
corpus_file=$file_path/wiki-18.jsonl
retriever_name=e5
retriever_path=intfloat/e5-base-v2

python search_r1/search/retrieval_server.py --index_path $index_file \
    --corpus_path $corpus_file \
    --topk 3 \
    --retriever_name $retriever_name \
    --retriever_model $retriever_path \
    --faiss_gpu
code/RL_model/verl/Search-R1/setup.py
ADDED
@@ -0,0 +1,54 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# setup.py is the fallback installation script when pyproject.toml does not work
from setuptools import setup, find_packages
import os

version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'verl/version/version')) as f:
    __version__ = f.read().strip()


with open('requirements.txt') as f:
    required = f.read().splitlines()
# Skip blank lines and comments (indexing item.strip()[0] on a blank line would raise IndexError)
install_requires = [item.strip() for item in required if item.strip() and item.strip()[0] != '#']

extras_require = {
    'test': ['pytest', 'yapf']
}

from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

setup(
    name='verl',
    version=__version__,
    package_dir={'': '.'},
    packages=find_packages(where='.'),
    url='https://github.com/volcengine/verl',
    license='Apache 2.0',
    author='Bytedance - Seed - MLSys',
    author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
    description='veRL: Volcano Engine Reinforcement Learning for LLM',
    install_requires=install_requires,
    extras_require=extras_require,
    package_data={'': ['version/*'],
                  'verl': ['trainer/config/*.yaml']},
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown'
)
code/RL_model/verl/Search-R1/train_grpo.sh
ADDED
@@ -0,0 +1,46 @@

export PYTORCH_CUDA_ALLOC_CONF=""
export EXPERIMENT_NAME=llm_guard_3B_10k_v2
export WAND_PROJECT='guard'
export CUDA_DEVICE_ORDER="PCI_BUS_ID"
export CUDA_VISIBLE_DEVICES=1,2
export VLLM_ATTENTION_BACKEND=FLASH_ATTN


PYTHONUNBUFFERED=1 NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 python3 -m verl.trainer.main_ppo \
    data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet \
    data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet \
    data.train_batch_size=64 \
    data.val_batch_size=64 \
    data.max_prompt_length=4096 \
    data.max_response_length=1024 \
    data.shuffle_train_dataloader=True \
    algorithm.adv_estimator=grpo \
    actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=true \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=64 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    trainer.logger=['wandb'] \
    trainer.n_gpus_per_node=2 \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=50 \
    trainer.project_name=$WAND_PROJECT \
    trainer.experiment_name=$EXPERIMENT_NAME \
    trainer.total_epochs=15 \
    trainer.total_training_steps=1005 \
    trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \
    do_search=false \
    max_turns=1 \
    2>&1 | tee $EXPERIMENT_NAME.log
code/RL_model/verl/Search-R1/train_ppo.sh
ADDED
@@ -0,0 +1,90 @@
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export DATA_DIR='data/nq_search'

WAND_PROJECT='Search-R1'

export BASE_MODEL='meta-llama/Llama-3.2-3B'
export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.2-3b-em
# export BASE_MODEL='meta-llama/Llama-3.2-3B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.2-3b-it-em
# export BASE_MODEL='meta-llama/Llama-3.1-8B'
# export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.1-8b-em
# export BASE_MODEL='meta-llama/Llama-3.1-8B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-llama3.1-8b-it-em

# export BASE_MODEL='Qwen/Qwen2.5-3B'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-3b-em
# export BASE_MODEL='Qwen/Qwen2.5-3B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-3b-it-em
# export BASE_MODEL='Qwen/Qwen2.5-7B'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-7b-em
# export BASE_MODEL='Qwen/Qwen2.5-7B-Instruct'
# export EXPERIMENT_NAME=nq-search-r1-ppo-qwen2.5-7b-it-em

# set -x
export VLLM_ATTENTION_BACKEND=XFORMERS # vllm + qwen2-7b with flash_attn has some issues

# max_prompt_length = (config['training']['max_start_length'] + config['training']['max_response_length'] * (config['training']['max_turns'] - 1) + config['training']['max_obs_length'] * config['training']['max_turns'])

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_data_num=null \
    data.val_data_num=null \
    data.train_batch_size=512 \
    data.val_batch_size=256 \
    data.max_prompt_length=4096 \
    data.max_response_length=500 \
    data.max_start_length=2048 \
    data.max_obs_length=500 \
    data.shuffle_train_dataloader=True \
    algorithm.adv_estimator=gae \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=64 \
    actor_rollout_ref.actor.fsdp_config.param_offload=true \
    actor_rollout_ref.actor.fsdp_config.grad_offload=true \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=true \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.n_agent=1 \
    actor_rollout_ref.rollout.temperature=1 \
    actor_rollout_ref.actor.state_masking=true \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.optim.lr_warmup_steps_ratio=0.015 \
    critic.model.path=$BASE_MODEL \
    critic.model.enable_gradient_checkpointing=true \
    critic.ppo_micro_batch_size=8 \
    critic.model.fsdp_config.param_offload=true \
    critic.model.fsdp_config.grad_offload=true \
    critic.model.fsdp_config.optimizer_offload=true \
    algorithm.kl_ctrl.kl_coef=0.001 \
    algorithm.no_think_rl=false \
    trainer.critic_warmup=0 \
    trainer.logger=['wandb'] \
    +trainer.val_only=false \
    +trainer.val_before_train=true \
    trainer.default_hdfs_dir=null \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=50 \
|
| 81 |
+
trainer.project_name=$WAND_PROJECT \
|
| 82 |
+
trainer.experiment_name=$EXPERIMENT_NAME \
|
| 83 |
+
trainer.total_epochs=15 \
|
| 84 |
+
trainer.total_training_steps=1005 \
|
| 85 |
+
trainer.default_hdfs_dir=null \
|
| 86 |
+
trainer.default_local_dir=verl_checkpoints/$EXPERIMENT_NAME \
|
| 87 |
+
max_turns=2 \
|
| 88 |
+
retriever.url="http://127.0.0.1:8000/retrieve" \
|
| 89 |
+
retriever.topk=3 \
|
| 90 |
+
2>&1 | tee $EXPERIMENT_NAME.log
|
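The commented `max_prompt_length` formula in the script can be sanity-checked against the values this particular command passes (`max_start_length=2048`, `max_response_length=500`, `max_obs_length=500`, `max_turns=2`). A minimal sketch, with those values hardcoded:

```python
# Derived prompt budget per the formula commented in the script,
# using the values from the command line above.
max_start_length = 2048
max_response_length = 500
max_obs_length = 500
max_turns = 2

derived = (max_start_length
           + max_response_length * (max_turns - 1)
           + max_obs_length * max_turns)
print(derived)  # 3548

# The derived budget must fit within data.max_prompt_length=4096.
assert derived <= 4096
```

With these settings the derived budget is 3548 tokens, so the configured `data.max_prompt_length=4096` leaves some headroom.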
code/RL_model/verl/verl_train/.git-blame-ignore-revs
ADDED
@@ -0,0 +1,13 @@
# Local usage: git config blame.ignoreRevsFile .git-blame-ignore-revs

# [dev] feat: immigrate from yapf & pylint to ruff based on pre-commit
# Changed 268 files, +10k/-9k lines. This is the biggest formatter change.
b00f77d8559b48d57a33c0132a5ba1c81891a536

# [ci] refactor: reduce ruff line-length from 300 to 120
# Changed 238 files, +6k/-1k lines. Global formatting change.
00a10a8ef389556f957a2f36132b2358fd6a109f

# [Lint] fix: linting errors in all files
# Changed 179 files, +1k/-3k lines. Global lint fix.
8e5ad4688a13de81727c014a3c2e2fb26324bc20
code/RL_model/verl/verl_train/.gitignore
ADDED
@@ -0,0 +1,130 @@
**/*.pt
**/checkpoints
**/wget-log
**/_build/
**/*.ckpt
**/outputs
**/*.tar.gz
**/playground
**/wandb

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
dataset/*
tensorflow/my_graph/*
.idea/
# C extensions
*.so

# Distribution / packaging
.Python
# env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
tmp/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
pytest.ini
output.txt

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
.venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# vscode
.vscode

# Mac
.DS_Store

# vim
*.swp

# emacs
*~

# ckpt
*.lock

# data
*.parquet

# local logs
logs
log
outputs
.history
code/RL_model/verl/verl_train/.gitmodules
ADDED
@@ -0,0 +1,3 @@
[submodule "recipe"]
	path = recipe
	url = https://github.com/verl-project/verl-recipe.git
code/RL_model/verl/verl_train/.log
ADDED
The diff for this file is too large to render.
code/RL_model/verl/verl_train/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,45 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: "v0.12.2"
    hooks:
      - id: ruff
        args: ["--fix", "--show-fixes", "--output-format=full"]
        exclude: ^.*\.(ipynb)$
      - id: ruff-format

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: "v1.17.0"
    hooks:
      - id: mypy

  - repo: local
    hooks:
      - id: autogen-trainer-cfg
        name: Generate and verify verl/trainer/config/_generated_*.yaml
        entry: scripts/generate_trainer_config.sh
        language: script
        pass_filenames: false

  - repo: local
    hooks:
      - id: check-docstrings
        name: Check doc string coverage
        entry: python3 tests/special_sanity/check_docstrings.py
        language: python
        pass_filenames: false

  - repo: local
    hooks:
      - id: check-license
        name: Check license
        entry: python3 tests/special_sanity/check_license.py --directories examples scripts tests verl setup.py
        language: python
        pass_filenames: false

  - repo: local
    hooks:
      - id: compileall
        name: Compile all python files
        entry: sh -c 'PYTHONWARNINGS=error python3 -m compileall -q .'
        language: python
        pass_filenames: false
code/RL_model/verl/verl_train/.readthedocs.yaml
ADDED
@@ -0,0 +1,19 @@
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

build:
  os: ubuntu-22.04
  tools:
    python: "3.11"
    rust: "1.70"

sphinx:
  configuration: docs/conf.py

python:
  install:
    - requirements: docs/requirements-docs.txt
    - method: pip
      path: .
code/RL_model/verl/verl_train/CONTRIBUTING.md
ADDED
@@ -0,0 +1,90 @@
# Contributing to verl

Thank you for considering a contribution to verl! We welcome contributions of any kind - bug fixes, enhancements, documentation improvements, or even just feedback. Whether you're an experienced developer or this is your first open-source project, your help is invaluable.

Your support can take many forms:
- Report issues or unexpected behaviors.
- Suggest or implement new features.
- Improve or expand documentation.
- Review pull requests and assist other contributors.
- Spread the word: share verl in blog posts, social media, or give the repo a ⭐.

## Finding Issues to Contribute

Looking for ways to dive in? Check out these issues:
- [Good first issues](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22)
- [Call for contribution](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22call%20for%20contribution%22)

Furthermore, you can learn the development plan and roadmap via [RFC](https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20label%3ARFC) and [Roadmap](https://github.com/volcengine/verl/issues?q=state%3Aopen%20label%3A%22roadmap%22).

## Developing

- **Python-only**: install verl via `pip install -e .[test,vllm]` or `pip install -e .[test,sglang]` and iterate quickly. For full dependency setup, check out the verl [installation doc](https://verl.readthedocs.io/en/latest/start/install.html).

## Code Linting and Formatting

We rely on pre-commit to keep our code consistent. To set it up:

```bash
pip install pre-commit
pre-commit install
# for staged changes
pre-commit run
# for all files in the repo
pre-commit run --all-files
# run a specific hook with pre-commit
# pre-commit run --all-files --show-diff-on-failure --color=always <hook-id>
pre-commit run --all-files --show-diff-on-failure --color=always ruff
pre-commit run --all-files --show-diff-on-failure --color=always autogen-trainer-cfg
```

## Testing

Our test suites run on GitHub Actions. Check these workflows for details:
- [GPU unit tests](https://github.com/volcengine/verl/blob/main/.github/workflows/gpu_unit_tests.yml)
- [CPU unit tests](https://github.com/volcengine/verl/blob/main/.github/workflows/cpu_unit_tests.yml)
- [vLLM tests](https://github.com/volcengine/verl/blob/main/.github/workflows/vllm.yml)
- [SGLang tests](https://github.com/volcengine/verl/blob/main/.github/workflows/sgl.yml)

### Adding CI tests

If possible, please add CI test(s) for your new feature:

1. Find the most relevant workflow yml file, which usually corresponds to a `hydra` default config (e.g. `ppo_trainer`, `ppo_megatron_trainer`, `sft_trainer`, etc.).
2. Add related path patterns to the `paths` section if not already included.
3. Minimize the workload of the test script(s) (see existing scripts for examples).

## Building the Docs

```bash
# Ensure verl is on your PYTHONPATH, e.g.:
pip install -e .[test]

# Install documentation dependencies
cd docs
pip install -r requirements-docs.txt

# Generate HTML docs
make clean
make html

# Preview locally
python -m http.server -d _build/html/
```

Open your browser at http://localhost:8000 to explore the docs.

## Pull Requests & Code Reviews

Thanks for submitting a PR! To streamline reviews:
- Follow our Pull Request Template for title format and checklist.
- Adhere to our pre-commit lint rules and ensure all checks pass.
- Update docs for any user-facing changes.
- Add or update tests in the CI workflows, or explain why tests aren't applicable.

## License

See the [LICENSE](https://github.com/volcengine/verl/blob/main/LICENSE) file for full details.

## Thank You

We appreciate your contributions to verl. Your efforts help make the project stronger and more user-friendly. Happy coding!