File size: 13,974 Bytes
0a7ac4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 | """
Unit Test Suite for TinyBert-CNN Intent Classifier Pipeline.
Tests: model init, dataset tokenization, forward pass, predict, compound splitter,
dataset generator output, and auto_trainer state I/O.
"""
import unittest
import os
import sys
import json
import tempfile
import torch
import pandas as pd
# Ensure the project directory is on sys.path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from TinyBert import IntentClassifier, IntentDataset, CompoundSentenceSplitter, TinyBertCNN
# =====================================================================
# 1. MODEL INITIALIZATION
# =====================================================================
class TestModelInit(unittest.TestCase):
    """Structural checks on a freshly constructed 5-class IntentClassifier."""

    @classmethod
    def setUpClass(cls):
        # Construct once for the whole class; the tests below only inspect it.
        cls.classifier = IntentClassifier(num_classes=5)

    def test_model_instance(self):
        model = self.classifier.model
        self.assertIsInstance(model, TinyBertCNN)

    def test_num_classes(self):
        configured = self.classifier.num_classes
        self.assertEqual(configured, 5)

    def test_device_assigned(self):
        device = self.classifier.device
        self.assertIsNotNone(device)

    def test_tokenizer_loaded(self):
        tokenizer = self.classifier.tokenizer
        self.assertIsNotNone(tokenizer)

    def test_model_has_batchnorm(self):
        """The model exposes a `batchnorms` collection with three entries."""
        model = self.classifier.model
        self.assertTrue(hasattr(model, 'batchnorms'))
        # Three entries, one per convolution filter size.
        self.assertEqual(len(model.batchnorms), 3)

    def test_model_has_hidden_fc(self):
        """The model carries both a hidden FC layer and its BatchNorm."""
        model = self.classifier.model
        for attribute in ('fc_hidden', 'bn_hidden'):
            self.assertTrue(hasattr(model, attribute))
# =====================================================================
# 2. INTENT DATASET
# =====================================================================
class TestIntentDataset(unittest.TestCase):
    """Test tokenization and tensor shapes produced by IntentDataset."""

    # Padding length requested from the dataset; every per-token tensor the
    # dataset returns is expected to have exactly this many positions.
    MAX_LENGTH = 128

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)
        cls.sample_data = [
            {'student_input': 'How do I use for loops?',
             'session_context': 'topic:For Loops | prev:If/Else | ability:If/Else:85% | emotion:engaged | pace:normal | slides:14,15,16',
             'label': 0},
            {'student_input': "What's the weather?",
             'session_context': 'topic:Variables | prev:None | ability:N/A | emotion:bored | pace:slow | slides:5,6,7',
             'label': 1},
        ]
        cls.dataset = IntentDataset(cls.sample_data, cls.classifier.tokenizer,
                                    max_length=cls.MAX_LENGTH)

    def test_dataset_length(self):
        self.assertEqual(len(self.dataset), len(self.sample_data))

    def test_output_keys(self):
        item = self.dataset[0]
        self.assertIn('input_ids', item)
        self.assertIn('attention_mask', item)
        self.assertIn('labels', item)

    def test_tensor_shapes(self):
        item = self.dataset[0]
        self.assertEqual(item['input_ids'].shape, torch.Size([self.MAX_LENGTH]))
        self.assertEqual(item['attention_mask'].shape, torch.Size([self.MAX_LENGTH]))

    def test_label_type(self):
        item = self.dataset[0]
        self.assertEqual(item['labels'].dtype, torch.long)

    def test_token_type_ids_present(self):
        """TinyBERT should produce token_type_ids for sentence pairs.

        When the dataset item has no 'token_type_ids' key the test is
        reported as skipped rather than silently passing, so the missing
        key is visible in the test report.
        """
        item = self.dataset[0]
        if 'token_type_ids' not in item:
            self.skipTest("IntentDataset item has no 'token_type_ids'")
        self.assertEqual(item['token_type_ids'].shape, torch.Size([self.MAX_LENGTH]))

    def test_handles_string_labels(self):
        """A string label is converted to its integer class index."""
        data = [{'student_input': 'test', 'session_context': 'ctx', 'label': 'Pace-Related'}]
        ds = IntentDataset(data, self.classifier.tokenizer)
        item = ds[0]
        # NOTE(review): assumes IntentDataset maps 'Pace-Related' to index 3;
        # confirm against the label table in TinyBert.
        self.assertEqual(item['labels'].item(), 3)
# =====================================================================
# 3. FORWARD PASS
# =====================================================================
class TestForwardPass(unittest.TestCase):
    """Run TinyBertCNN.forward on random token ids and check the logits shape."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)

    def _forward(self, batch_size, seq_len=128, with_token_type_ids=False):
        """Build dummy inputs, switch the model to eval mode and return logits.

        Token ids are drawn uniformly from [0, 1000); the attention mask is
        all ones; optional token_type_ids are all zeros and passed by keyword.
        """
        device = self.classifier.device
        input_ids = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long).to(device)
        extra_kwargs = {}
        if with_token_type_ids:
            extra_kwargs['token_type_ids'] = torch.zeros(
                batch_size, seq_len, dtype=torch.long).to(device)
        model = self.classifier.model
        model.eval()
        with torch.no_grad():
            return model(input_ids, attention_mask, **extra_kwargs)

    def test_output_shape(self):
        logits = self._forward(4)
        self.assertEqual(logits.shape, torch.Size([4, 5]))

    def test_output_with_token_type_ids(self):
        logits = self._forward(2, with_token_type_ids=True)
        self.assertEqual(logits.shape, torch.Size([2, 5]))

    def test_single_sample(self):
        """A batch of one must work at inference time.

        In eval mode BatchNorm normalises with its running statistics, so a
        single sample is valid input there (unlike in training mode).
        """
        logits = self._forward(1)
        self.assertEqual(logits.shape, torch.Size([1, 5]))
# =====================================================================
# 4. PREDICT
# =====================================================================
class TestPredict(unittest.TestCase):
    """Test the predict() method with real text.

    Every test checks both outputs of predict(): one prediction per input,
    and one row of per-class probabilities per input.
    """

    NUM_CLASSES = 5

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=cls.NUM_CLASSES)

    def _assert_prediction_shapes(self, preds, probs, n_inputs):
        """Assert len(preds) == n_inputs and probs has shape (n_inputs, NUM_CLASSES)."""
        self.assertEqual(len(preds), n_inputs)
        # tuple() normalises torch.Size / numpy shape tuples for comparison.
        self.assertEqual(tuple(probs.shape), (n_inputs, self.NUM_CLASSES))

    def test_predict_with_context(self):
        preds, probs = self.classifier.predict(
            ["How do loops work?"],
            ["topic:For Loops | prev:None | ability:N/A | emotion:neutral | pace:normal | slides:10,11,12"]
        )
        self._assert_prediction_shapes(preds, probs, 1)

    def test_predict_without_context(self):
        preds, probs = self.classifier.predict(["I'm feeling frustrated"])
        self._assert_prediction_shapes(preds, probs, 1)

    def test_predict_empty_string(self):
        """Empty input should not crash and still yields one prediction."""
        preds, probs = self.classifier.predict([""])
        self._assert_prediction_shapes(preds, probs, 1)

    def test_predict_multiple(self):
        preds, probs = self.classifier.predict(
            ["Hello", "Can you repeat?", "Speed up please"],
            ["ctx1", "ctx2", "ctx3"]
        )
        self._assert_prediction_shapes(preds, probs, 3)
# =====================================================================
# 5. COMPOUND SENTENCE SPLITTER
# =====================================================================
class TestCompoundSplitter(unittest.TestCase):
    """Edge cases for CompoundSentenceSplitter.split_compound_question."""

    @classmethod
    def setUpClass(cls):
        cls.splitter = CompoundSentenceSplitter()

    def _count_parts(self, text):
        """Return how many pieces the splitter breaks *text* into."""
        return len(self.splitter.split_compound_question(text))

    def test_compound_question_splits(self):
        parts = self._count_parts("What is a variable and how do I use it?")
        self.assertGreaterEqual(parts, 2)

    def test_single_question_no_split(self):
        self.assertEqual(self._count_parts("How do loops work?"), 1)

    def test_non_question_no_split(self):
        self.assertEqual(self._count_parts("I like programming."), 1)

    def test_multiple_question_marks(self):
        self.assertEqual(self._count_parts("What is a loop? How does it work?"), 2)

    def test_empty_string(self):
        self.assertEqual(self._count_parts(""), 1)
# =====================================================================
# 6. DATASET GENERATOR
# =====================================================================
class TestDatasetGenerator(unittest.TestCase):
    """Test that the dataset generator produces correct output.

    build_dataset() is run inside a throw-away temporary directory (made the
    current working directory), and the resulting data/{train,val,test}.csv
    files are loaded for inspection.
    """

    @classmethod
    def setUpClass(cls):
        from dataset_generator import build_dataset
        cls.original_dir = os.getcwd()
        cls.tmp_dir = tempfile.mkdtemp()
        try:
            os.chdir(cls.tmp_dir)
            # Generate a small dataset (20 samples per class).
            build_dataset(num_samples_per_class=20)
            cls.train_df = pd.read_csv('data/train.csv')
            cls.val_df = pd.read_csv('data/val.csv')
            cls.test_df = pd.read_csv('data/test.csv')
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises;
            # without this the cwd would stay inside tmp_dir for every later
            # test class, and the temp directory would leak.
            cls._restore_environment()
            raise

    @classmethod
    def tearDownClass(cls):
        cls._restore_environment()

    @classmethod
    def _restore_environment(cls):
        """Return to the original cwd and delete the temporary directory."""
        import shutil
        # Leave tmp_dir before deleting it (required on Windows).
        os.chdir(cls.original_dir)
        # Best-effort cleanup: a leftover temp dir must not fail the suite.
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)

    def test_columns_exist(self):
        for col in ['student_input', 'session_context', 'label', 'intent_name']:
            self.assertIn(col, self.train_df.columns)

    def test_three_splits_exist(self):
        self.assertGreater(len(self.train_df), 0)
        self.assertGreater(len(self.val_df), 0)
        self.assertGreater(len(self.test_df), 0)

    def test_all_classes_present(self):
        all_labels = set(self.train_df['label'].unique())
        self.assertEqual(all_labels, {0, 1, 2, 3, 4})

    def test_compact_context_format(self):
        ctx = self.train_df.iloc[0]['session_context']
        self.assertIn('topic:', ctx)
        self.assertIn('prev:', ctx)
        self.assertIn('emotion:', ctx)

    def test_no_empty_inputs(self):
        self.assertFalse(self.train_df['student_input'].isna().any())
        self.assertFalse(self.train_df['session_context'].isna().any())
# =====================================================================
# 7. AUTO TRAINER STATE
# =====================================================================
class TestAutoTrainerState(unittest.TestCase):
    """Test load_state / save_state round-trip.

    Both tests touch auto_trainer.STATE_FILE on disk, so each one snapshots
    the file first and puts it back exactly afterwards.
    """

    @staticmethod
    def _snapshot(path):
        """Return the raw bytes of *path*, or None if the file does not exist."""
        try:
            with open(path, 'rb') as f:
                return f.read()
        except FileNotFoundError:
            return None

    @staticmethod
    def _restore(path, snapshot):
        """Put *path* back to the state captured by _snapshot().

        A non-None snapshot (including an empty b'') is written back verbatim.
        A None snapshot means the file did not exist before, so any file the
        test created is removed.
        """
        if snapshot is not None:
            with open(path, 'wb') as f:
                f.write(snapshot)
        elif os.path.exists(path):
            os.remove(path)

    def test_state_round_trip(self):
        from auto_trainer import load_state, save_state, STATE_FILE
        snapshot = self._snapshot(STATE_FILE)
        try:
            test_state = {"sessions_since_last_train": 42, "total_sessions": 100}
            save_state(test_state)
            loaded = load_state()
            self.assertEqual(loaded["sessions_since_last_train"], 42)
            self.assertEqual(loaded["total_sessions"], 100)
        finally:
            self._restore(STATE_FILE, snapshot)

    def test_default_state(self):
        """With no state file on disk, load_state() reports zero counters."""
        from auto_trainer import load_state, STATE_FILE
        snapshot = self._snapshot(STATE_FILE)
        try:
            if snapshot is not None:
                os.remove(STATE_FILE)
            state = load_state()
            self.assertEqual(state["sessions_since_last_train"], 0)
            self.assertEqual(state["total_sessions"], 0)
        finally:
            self._restore(STATE_FILE, snapshot)
if __name__ == "__main__":
    # Run every TestCase defined above, printing one line per test.
    unittest.main(verbosity=2)
|