| import catboost |
| import lightgbm |
|
|
| import os |
| import sys |
| import re |
| import json |
| import inspect |
| import shutil |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from collections import Counter |
| from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| from huggingface_hub import snapshot_download, login, create_repo, HfApi |
| from sklearn.preprocessing import LabelEncoder |
|
|
class EssayRegressionHead(nn.Module):
    """Two-layer regression head over first-token + masked-mean pooling.

    The first-position vector and the attention-masked mean of all positions
    are concatenated (``2 * hidden_size`` features), passed through dropout,
    a 512-unit linear layer with GELU, and projected to a single value.
    """

    def __init__(self, hidden_size, dropout=0.15):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Input width is doubled because two pooled vectors are concatenated.
        self.fc1 = nn.Linear(hidden_size * 2, 512)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(512, 1)

    def forward(self, hidden_states, attention_mask):
        """Return one regression value per row of ``hidden_states``.

        ``hidden_states`` is indexed as ``[:, 0, :]`` and summed over dim 1;
        ``attention_mask`` is broadcast over the last dim as a float weight.
        """
        first_token = hidden_states[:, 0, :]
        weights = attention_mask.unsqueeze(-1).float()
        weighted_sum = (hidden_states * weights).sum(1)
        # Clamp keeps the division finite when a row has an all-zero mask.
        token_count = weights.sum(1).clamp(min=1e-9)
        features = torch.cat([first_token, weighted_sum / token_count], dim=-1)
        hidden = self.fc1(self.dropout(features))
        return self.fc2(self.act(hidden))
|
|
class EssayRegressorModel(nn.Module):
    """Transformer backbone plus ``EssayRegressionHead`` for scalar scoring.

    ``config_or_name`` is either a ``PretrainedConfig`` (the backbone is then
    built from the config without loading pretrained weights) or a model
    name/path that is forwarded to ``AutoConfig`` / ``AutoModel``.
    """

    def __init__(self, config_or_name, dropout=0.15, **kwargs):
        super().__init__()
        if isinstance(config_or_name, PretrainedConfig):
            self.config = config_or_name
            self.backbone = AutoModel.from_config(self.config)
        else:
            self.config = AutoConfig.from_pretrained(config_or_name, **kwargs)
            self.backbone = AutoModel.from_pretrained(config_or_name, config=self.config)
        self.head = EssayRegressionHead(self.config.hidden_size, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run backbone and head; return a ``SequenceClassifierOutput``.

        ``loss`` is a Huber loss (delta=1.0) between the squeezed logits and
        ``labels`` when labels are given, otherwise ``None``.
        """
        if attention_mask is None:
            # The head needs a real mask for its mean pooling; without this
            # the advertised default of None failed on `.unsqueeze`.
            attention_mask = torch.ones_like(input_ids)
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        logits = self.head(out.last_hidden_state, attention_mask)
        loss = None
        if labels is not None:
            loss = torch.nn.functional.huber_loss(logits.squeeze(-1), labels.float(), delta=1.0)
        return SequenceClassifierOutput(loss=loss, logits=logits)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the module state with every tensor made contiguous."""
        state = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        return {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state.items()}

    @classmethod
    def from_pretrained(cls, repo_id_or_path, base_model_name="google/electra-large-discriminator", dropout=0.15):
        """Build the model and load weights from a local directory or hub repo.

        A value containing ``/`` that is not an existing path is treated as a
        hub repo id and downloaded. ``pytorch_model.bin`` is preferred over
        ``model.safetensors``. Weights are loaded non-strictly; the first
        three missing/unexpected keys are printed.

        Raises:
            FileNotFoundError: if neither weight file exists.
        """
        if "/" in repo_id_or_path and not os.path.exists(repo_id_or_path):
            local_path = snapshot_download(repo_id=repo_id_or_path)
        else:
            local_path = repo_id_or_path

        config = AutoConfig.from_pretrained(local_path)
        if config.hidden_size != 1024:
            # The repo config does not report the expected width; fall back
            # to the base model's config.
            config = AutoConfig.from_pretrained(base_model_name)
        model = cls(config, dropout=dropout)

        weights_path = os.path.join(local_path, "pytorch_model.bin")
        if not os.path.exists(weights_path):
            weights_path = os.path.join(local_path, "model.safetensors")
        if not os.path.exists(weights_path):
            raise FileNotFoundError("weights not found in " + local_path)

        if weights_path.endswith(".bin"):
            # NOTE(security): .bin checkpoints are pickle-based. Only load
            # repos you trust, or pass weights_only=True where the installed
            # torch version supports it.
            state_dict = torch.load(weights_path, map_location="cpu")
        else:
            from safetensors.torch import load_file
            state_dict = load_file(weights_path)

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print("MISSING:", missing[:3])
        if unexpected:
            print("UNEXPECTED:", unexpected[:3])
        return model
|
|
# Width of the hidden layer inside CompactRegressionHead.
HEAD_HIDDEN = 512


class CompactRegressionHead(nn.Module):
    """LayerNorm -> dropout -> linear -> GELU -> dropout -> linear(1) head.

    Takes an already pooled vector of size ``hidden_size`` and produces a
    single regression value per row.
    """

    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout1 = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, HEAD_HIDDEN)
        self.act = nn.GELU()
        self.dropout2 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(HEAD_HIDDEN, 1)

    def forward(self, pooled_output):
        """Apply the head's layers in sequence to ``pooled_output``."""
        normalized = self.dropout1(self.layer_norm(pooled_output))
        activated = self.dropout2(self.act(self.fc1(normalized)))
        return self.fc2(activated)
|
|
class ModernBERTRegressorModel(nn.Module):
    """Transformer backbone with masked-mean pooling and a compact head.

    ``config_or_name`` is either a ``PretrainedConfig`` (backbone built from
    the config without pretrained weights) or a model name/path forwarded to
    ``AutoConfig`` / ``AutoModel``.
    """

    def __init__(self, config_or_name, dropout=0.1, **kwargs):
        super().__init__()
        if isinstance(config_or_name, PretrainedConfig):
            self.config = config_or_name
            self.backbone = AutoModel.from_config(self.config)
        else:
            self.config = AutoConfig.from_pretrained(config_or_name, **kwargs)
            self.backbone = AutoModel.from_pretrained(config_or_name, config=self.config)
        self.head = CompactRegressionHead(self.config.hidden_size, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Return a ``SequenceClassifierOutput`` with regression logits.

        ``loss`` is a Huber loss (delta=1.0) against ``labels`` when given,
        otherwise ``None``.
        """
        if attention_mask is None:
            # Pooling below needs a concrete mask; the advertised default of
            # None previously failed on `.unsqueeze`.
            attention_mask = torch.ones_like(input_ids)
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            **kwargs,
        )
        mask = attention_mask.unsqueeze(-1).float()
        # Masked mean over dim 1; clamp avoids division by zero for rows
        # whose mask is entirely zero.
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        logits = self.head(pooled)
        loss = None
        if labels is not None:
            loss = nn.functional.huber_loss(logits.squeeze(-1), labels.float(), delta=1.0)
        return SequenceClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(cls, repo_id_or_path, base_model_name="answerdotai/ModernBERT-base", dropout=0.1):
        """Build the model and load weights from a local directory or hub repo.

        A value containing ``/`` that is not an existing path is treated as a
        hub repo id and downloaded. The directory's ``config.json`` is used
        when present, otherwise the config of ``base_model_name``.
        ``model.safetensors`` is preferred over ``pytorch_model.bin``.
        Weights are loaded non-strictly; the first three missing/unexpected
        keys are printed.

        Raises:
            FileNotFoundError: if neither weight file exists.
        """
        if "/" in repo_id_or_path and not os.path.exists(repo_id_or_path):
            local_path = snapshot_download(repo_id=repo_id_or_path)
        else:
            local_path = repo_id_or_path

        config_path = os.path.join(local_path, "config.json")
        if os.path.exists(config_path):
            config = AutoConfig.from_pretrained(local_path)
        else:
            config = AutoConfig.from_pretrained(base_model_name)
        model = cls(config, dropout=dropout)

        safetensors_path = os.path.join(local_path, "model.safetensors")
        bin_path = os.path.join(local_path, "pytorch_model.bin")
        if os.path.exists(safetensors_path):
            from safetensors.torch import load_file
            state_dict = load_file(safetensors_path)
        elif os.path.exists(bin_path):
            # NOTE(security): .bin checkpoints are pickle-based. Only load
            # repos you trust, or pass weights_only=True where the installed
            # torch version supports it.
            state_dict = torch.load(bin_path, map_location="cpu")
        else:
            raise FileNotFoundError("weights not found in " + local_path)

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print("MISSING:", missing[:3])
        if unexpected:
            print("UNEXPECTED:", unexpected[:3])
        return model
|
|
class TextCNNRegressor(nn.Module):
    """Embedding + parallel 1-D convolutions + max-over-time pooling regressor.

    Each kernel size in ``filter_sizes`` gets its own ``Conv1d`` with
    ``num_filters`` output channels; the max-pooled outputs are concatenated,
    passed through dropout and projected to one value.
    """

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(filter_sizes), 1)
        # Minimal stand-in object exposing `hidden_size` like the
        # transformer configs used by the other models in this file.
        self.config = type('obj', (object,), {'hidden_size': num_filters * len(filter_sizes)})()

    def forward(self, input_ids, labels=None):
        """Return a ``SequenceClassifierOutput`` with regression logits.

        ``loss`` is a Huber loss (delta=1.0) against ``labels`` when given,
        otherwise ``None``.
        """
        # A Conv1d cannot run on a sequence shorter than its kernel, so very
        # short inputs are right-padded with index 0 (the embedding's
        # padding_idx) up to the largest kernel size.
        min_len = max((conv.kernel_size[0] for conv in self.convs), default=0)
        seq_len = input_ids.shape[1]
        if seq_len < min_len:
            input_ids = nn.functional.pad(input_ids, (0, min_len - seq_len), value=0)

        x = self.embedding(input_ids).permute(0, 2, 1)
        cnn_features = []
        for conv in self.convs:
            feat_map = torch.nn.functional.relu(conv(x))
            # Max over the whole remaining time axis.
            pooled = torch.nn.functional.max_pool1d(feat_map, feat_map.shape[2]).squeeze(2)
            cnn_features.append(pooled)
        x = torch.cat(cnn_features, dim=1)
        x = self.dropout(x)
        logits = self.fc(x)
        loss = None
        if labels is not None:
            loss = nn.functional.huber_loss(logits.squeeze(-1), labels.float(), delta=1.0)
        return SequenceClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(cls, repo_id_or_path):
        """Load the model described by ``textcnn_config.json`` plus its weights.

        A value containing ``/`` that is not an existing path is treated as a
        hub repo id and downloaded. ``model.safetensors`` is preferred over
        ``pytorch_model.bin``. Weights are loaded non-strictly; the first
        three missing/unexpected keys are printed.

        Returns:
            ``(model, tokenizer_name)`` where ``tokenizer_name`` comes from
            the config and defaults to ``"bert-base-uncased"``.

        Raises:
            FileNotFoundError: if neither weight file exists.
        """
        if "/" in repo_id_or_path and not os.path.exists(repo_id_or_path):
            local_path = snapshot_download(repo_id=repo_id_or_path)
        else:
            local_path = repo_id_or_path

        with open(os.path.join(local_path, "textcnn_config.json"), encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(
            vocab_size=cfg["vocab_size"],
            embed_dim=cfg["embed_dim"],
            num_filters=cfg["num_filters"],
            filter_sizes=cfg["filter_sizes"],
            # Fall back to the constructor default when the config omits it.
            dropout=cfg.get("dropout", 0.3),
        )

        safetensors_path = os.path.join(local_path, "model.safetensors")
        bin_path = os.path.join(local_path, "pytorch_model.bin")
        if os.path.exists(safetensors_path):
            # Imported lazily so a .bin-only checkpoint does not require
            # the safetensors package.
            from safetensors.torch import load_file
            state_dict = load_file(safetensors_path)
        elif os.path.exists(bin_path):
            # NOTE(security): .bin checkpoints are pickle-based. Only load
            # repos you trust, or pass weights_only=True where the installed
            # torch version supports it.
            state_dict = torch.load(bin_path, map_location="cpu")
        else:
            raise FileNotFoundError("weights not found in " + local_path)

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print("MISSING:", missing[:3])
        if unexpected:
            print("UNEXPECTED:", unexpected[:3])
        return model, cfg.get("tokenizer_name", "bert-base-uncased")
|
|
| def _count_sentences(text): |
| return len(re.findall(r'[.!?]+', str(text))) + 1 |
|
|
| def _avg_word_length(text): |
| words = str(text).split() |
| return np.mean([len(w) for w in words]) if words else 0.0 |
|
|
| def _lexical_diversity(text): |
| words = str(text).lower().split() |
| return len(set(words)) / len(words) if words else 0.0 |
|
|
| def _count_paragraphs(text): |
| return len([p for p in str(text).split('\n') if p.strip()]) |
|
|
| def _count_punctuation(text): |
| return sum(1 for c in str(text) if c in '.,;:!?()[]{}"\'-') |
|
|
| def _count_connectives(text): |
| connectives = ['however', 'therefore', 'furthermore', 'moreover', 'although', 'nevertheless', 'consequently', 'in addition', 'for example', 'in conclusion', 'on the other hand', 'as a result', 'thus', 'hence', 'meanwhile', 'subsequently', 'additionally'] |
| text_lower = str(text).lower() |
| return sum(text_lower.count(c) for c in connectives) |
|
|
| def _count_spelling_errors(text): |
| consonants = set('bcdfghjklmnpqrstvwxyz') |
| count = 0 |
| for word in str(text).lower().split(): |
| run = 0 |
| for ch in word: |
| if ch in consonants: |
| run += 1 |
| if run >= 4: |
| count += 1 |
| break |
| else: |
| run = 0 |
| return count |
|
|
| def _source_overlap(essay, source): |
| if not source or pd.isna(source): |
| return 0.0 |
| essay_words = set(str(essay).lower().split()) |
| source_words = set(str(source).lower().split()) |
| return len(essay_words & source_words) / len(essay_words) if essay_words else 0.0 |
|
|
| def _count_common_misspellings(text): |
| text_lower = str(text).lower() |
| error_patterns = [r'\bprinciple\b', r'\baloud\b', r'\bu\b', r'\bur\b', r'\bthier\b', r'\bteh\b', r'\btaht\b', r'\bwhta\b', r'\bdont\b', r'\bcant\b', r'\bwont\b', r'\bdoesnt\b', r'\bwasnt\b', r'\bwerent\b', r'\bhasnt\b', r'\bhavent\b', r'\bshouldnt\b', r'\bcouldnt\b', r'\bwouldnt\b', r'\bim\b', r'\bive\b'] |
| count = sum(len(re.findall(p, text_lower)) for p in error_patterns) |
| sentences = re.split(r'[.!?]+', text_lower) |
| bigrams = [] |
| for sent in sentences: |
| words = sent.split() |
| for i in range(len(words) - 1): |
| bigrams.append((words[i], words[i + 1])) |
| repeated_bigrams = sum(1 for v in Counter(bigrams).values() if v > 2) |
| return count + repeated_bigrams |
|
|
| def _essay_structure_score(text): |
| text_lower = str(text).lower() |
| has_greeting = bool(re.search(r'\b(dear|to\s+the|hello|hi)\b', text_lower[:100])) |
| has_conclusion = bool(re.search(r'\b(in\s+conclusion|to\s+conclude|in\s+summary|overall|therefore|thus)\b', text_lower[-300:])) |
| body_markers = len(re.findall(r'\b(first|second|third|fourth|fifth|next|also|another|finally|lastly)\b', text_lower)) |
| has_closing = bool(re.search(r'\b(sincerely|thank\s+you|yours\s+truly|best\s+regards)\b', text_lower[-200:])) |
| return min(has_greeting * 0.25 + has_conclusion * 0.25 + min(body_markers, 5) * 0.1 + has_closing * 0.25, 1.0) |
|
|
| def _argument_quality_score(text): |
| text_lower = str(text).lower() |
| evidence = len(re.findall(r'\b(for\s+example|such\s+as|according\s+to|research\s+shows|studies\s+show|data|statistics|percent|%)\b', text_lower)) |
| specificity = len(re.findall(r'\b\d+\b', text_lower)) |
| personal = len(re.findall(r'\b(i\s+think|i\s+believe|in\s+my\s+opinion|from\s+my\s+experience|i\s+have\s+seen|i\s+know)\b', text_lower)) |
| words = text_lower.split() |
| unique_ratio = len(set(words)) / len(words) if words else 0 |
| return min(min(evidence, 3) * 0.2 + min(specificity, 5) * 0.1 + min(personal, 3) * 0.15 + unique_ratio * 0.55, 1.0) |
|
|
| def _readability_features(text): |
| sentences = [s.strip() for s in re.split(r'[.!?]+', str(text)) if s.strip()] |
| words = str(text).split() |
| if not sentences or not words: |
| return {'avg_sentence_length': 0, 'avg_syllables': 0, 'flesch_score': 0} |
| def count_syllables(word): |
| word = word.lower().strip('.,;:!?"\'') |
| if not word: |
| return 0 |
| vowels = 'aeiouy' |
| count, prev_was_vowel = 0, False |
| for char in word: |
| if char in vowels: |
| if not prev_was_vowel: |
| count += 1 |
| prev_was_vowel = True |
| else: |
| prev_was_vowel = False |
| if word.endswith('e'): |
| count -= 1 |
| return max(count, 1) |
| total_syllables = sum(count_syllables(w) for w in words) |
| avg_sentence_length = len(words) / len(sentences) |
| avg_syllables = total_syllables / len(words) |
| flesch = 206.835 - 1.015 * avg_sentence_length - 84.6 * avg_syllables if avg_sentence_length > 0 else 0 |
| return {'avg_sentence_length': avg_sentence_length, 'avg_syllables': avg_syllables, 'flesch_score': flesch} |
|
|
| def _sentence_length_std(t): |
| lengths = [len(s.split()) for s in re.split(r'[.!?]+', str(t)) if s.strip()] |
| return np.std(lengths) if lengths else 0 |
|
|
def build_features(df):
    """Compute the hand-crafted essay feature table for ``df``.

    Reads ``df['full_text']`` and, when present, the ``source_text``,
    ``task`` and ``prompt_name`` columns. Returns a new DataFrame with one
    row per input row, in input order, on a fresh 0..n-1 index.

    The feature formulas are intentionally left exactly as they are: a
    downstream model consumes these columns, so changing a heuristic would
    change its inputs.
    """
    # Use a positional index from the start. The result is re-indexed
    # positionally anyway, and this keeps every Series aligned even when
    # `df` has a non-default index (the fallback `source` below used to
    # mismatch and turn `has_source` into NaN).
    text = df['full_text'].fillna('').reset_index(drop=True)
    if 'source_text' in df.columns:
        source = df['source_text'].fillna('').reset_index(drop=True)
    else:
        source = pd.Series('', index=text.index, dtype=object)

    feat = pd.DataFrame(index=text.index)

    # --- size and shape -------------------------------------------------
    feat['char_count'] = text.str.len()
    feat['word_count'] = text.str.split().str.len()
    feat['sentence_count'] = text.apply(_count_sentences)
    feat['paragraph_count'] = text.apply(_count_paragraphs)
    feat['avg_word_len'] = text.apply(_avg_word_length)
    feat['avg_sentence_len'] = feat['word_count'] / feat['sentence_count'].clip(lower=1)
    feat['avg_paragraph_len'] = feat['word_count'] / feat['paragraph_count'].clip(lower=1)
    feat['lexical_diversity'] = text.apply(_lexical_diversity)
    feat['punctuation_count'] = text.apply(_count_punctuation)
    feat['punct_per_word'] = feat['punctuation_count'] / feat['word_count'].clip(lower=1)
    feat['connective_count'] = text.apply(_count_connectives)
    feat['connective_per_sent'] = feat['connective_count'] / feat['sentence_count'].clip(lower=1)
    feat['spelling_proxy'] = text.apply(_count_spelling_errors)

    # --- source and prompt metadata --------------------------------------
    feat['source_overlap'] = [_source_overlap(e, s) for e, s in zip(text, source)]
    feat['has_source'] = (source.str.len() > 10).astype(int)
    # NOTE(review): the encoder is fitted on the frame being featurized, so
    # the integer codes depend on which labels are present in this call.
    # Confirm this matches how the consuming model was trained.
    le = LabelEncoder()
    feat['task_enc'] = le.fit_transform(df['task'].fillna('unknown')) if 'task' in df.columns else 0
    feat['prompt_enc'] = le.fit_transform(df['prompt_name'].fillna('unknown')) if 'prompt_name' in df.columns else 0

    # --- transforms and composite scores ---------------------------------
    feat['log_word_count'] = np.log1p(feat['word_count'])
    feat['log_char_count'] = np.log1p(feat['char_count'])
    feat['word_count_sq'] = feat['word_count'] ** 2
    feat['lex_div_sq'] = feat['lexical_diversity'] ** 2
    feat['misspelling_count'] = text.apply(_count_common_misspellings)
    feat['misspelling_rate'] = feat['misspelling_count'] / feat['word_count'].clip(lower=1)
    feat['structure_score'] = text.apply(_essay_structure_score)
    feat['argument_quality'] = text.apply(_argument_quality_score)

    # Build the readability columns straight from the per-row dicts; this
    # avoids creating one Series per row and guarantees the three columns
    # exist even for an empty input.
    readability = pd.DataFrame(
        [_readability_features(t) for t in text],
        index=text.index,
        columns=['avg_sentence_length', 'avg_syllables', 'flesch_score'],
    )
    feat = pd.concat([feat, readability], axis=1)

    feat['char_per_word'] = feat['char_count'] / feat['word_count'].clip(lower=1)
    feat['sent_per_paragraph'] = feat['sentence_count'] / feat['paragraph_count'].clip(lower=1)
    feat['long_words_ratio'] = text.apply(lambda x: sum(1 for w in str(x).split() if len(w) > 6) / max(len(str(x).split()), 1))
    feat['repeated_words_ratio'] = text.apply(lambda x: 1 - len(set(str(x).lower().split())) / max(len(str(x).split()), 1))
    feat['sentence_length_std'] = text.apply(_sentence_length_std)

    # --- register and grammar cues (substring / regex counts) ------------
    feat['formal_markers'] = text.apply(lambda x: sum(1 for m in ['dear', 'sincerely', 'thank you', 'yours truly', 'regards', 'to the principal', 'to the teacher'] if m in str(x).lower()))
    feat['informal_markers'] = text.apply(lambda x: sum(1 for m in ['lol', 'omg', 'btw', 'gonna', 'wanna', 'gotta', 'kinda', 'sorta', 'dunno', 'lemme', 'gimme', 'ya', 'yea', 'nah', 'nope', 'whatever'] if m in str(x).lower()))
    feat['grammar_errors'] = text.apply(lambda x: (len(re.findall(r'\bthere\s+(phones?|cell|friends?|parents?|teachers?|students?|schools?)', str(x).lower())) + len(re.findall(r'\byour\s+(going|gonna|coming|doing)', str(x).lower())) + len(re.findall(r'\b(principle|aloud|thier|teh|taht|whta)\b', str(x).lower()))))
    feat['discourse_markers'] = text.apply(lambda x: sum(len(re.findall(r'\b' + m + r'\b', str(x).lower())) for m in ['first', 'second', 'third', 'next', 'also', 'another', 'finally', 'lastly', 'however', 'therefore', 'furthermore', 'moreover', 'although', 'nevertheless', 'consequently', 'in addition', 'for example', 'in conclusion', 'on the other hand', 'as a result', 'thus', 'hence', 'meanwhile', 'subsequently', 'additionally', 'ultimately', 'overall', 'in summary', 'to sum up']))

    # --- topic keyword counts ---------------------------------------------
    feat['policy_mentions'] = text.apply(lambda x: len(re.findall(r'policy\s*1|policy one|first policy|policy\s*2|policy two|second policy', str(x).lower())))
    feat['emergency_mentions'] = text.apply(lambda x: len(re.findall(r'emergency|911|police|ambulance|fire', str(x).lower())))
    feat['parent_mentions'] = text.apply(lambda x: len(re.findall(r'parent|mom|dad|mother|father|guardian', str(x).lower())))
    feat['cheating_mentions'] = text.apply(lambda x: len(re.findall(r'cheat|cheating|plagiariz', str(x).lower())))
    feat['distraction_mentions'] = text.apply(lambda x: len(re.findall(r'distract|disrupt|interrupt|noise', str(x).lower())))
    feat['safety_mentions'] = text.apply(lambda x: len(re.findall(r'safe|safety|secure|protect|danger', str(x).lower())))
    feat['responsibility_mentions'] = text.apply(lambda x: len(re.findall(r'responsib|trust|mature|adult', str(x).lower())))

    # --- character-level ratios -------------------------------------------
    feat['repetition_score'] = text.apply(lambda x: sum(1 for v in Counter(str(x).lower().split()).values() if v > 3) / max(len(str(x).split()), 1))
    feat['capitalization_ratio'] = text.apply(lambda x: sum(1 for c in str(x) if c.isupper()) / max(len(str(x)), 1))
    feat['exclamation_ratio'] = text.apply(lambda x: str(x).count('!') / max(len(str(x)), 1))
    feat['question_ratio'] = text.apply(lambda x: str(x).count('?') / max(len(str(x)), 1))
    feat['comma_ratio'] = text.apply(lambda x: str(x).count(',') / max(len(str(x)), 1))
    feat['unique_word_ratio'] = text.apply(lambda x: len(set(str(x).lower().split())) / max(len(str(x).split()), 1))
    return feat.reset_index(drop=True)
|
|
class EssayEnsembleModel(nn.Module):
    """Weighted ensemble of ELECTRA, ModernBERT, CatBoost and TextCNN scorers.

    ``config`` is a mapping with a ``weights`` entry (model key -> weight),
    the ``*_repo`` entries read by ``load_all`` and optional ``score_min`` /
    ``score_max`` clipping bounds (defaults 1.0 and 6.0).
    """

    MODEL_KEYS = ["electra", "modernbert", "catboost", "textcnn"]

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.electra_tokenizer = None
        self.modernbert_tokenizer = None
        self.textcnn_tokenizer = None
        self.electra_model = None
        self.modernbert_model = None
        self.textcnn_model = None
        self.cat_model = None
        self.weights = config["weights"]
        self.score_min = config.get("score_min", 1.0)
        self.score_max = config.get("score_max", 6.0)

    def load_all(self):
        """Download/load every sub-model and tokenizer named in the config."""
        print("loading electra")
        repo = self.config["electra_repo"]
        self.electra_tokenizer = AutoTokenizer.from_pretrained(repo)
        self.electra_model = EssayRegressorModel.from_pretrained(repo, base_model_name="google/electra-large-discriminator")
        self.electra_model.to(self.device).eval()

        print("loading modernbert")
        repo = self.config["modernbert_repo"]
        # The tokenizer comes from the base checkpoint, not from `repo`.
        self.modernbert_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
        self.modernbert_model = ModernBERTRegressorModel.from_pretrained(repo, base_model_name="answerdotai/ModernBERT-base")
        self.modernbert_model.to(self.device).eval()

        print("loading catboost")
        catboost_local = snapshot_download(repo_id=self.config["catboost_repo"])
        # NOTE(security): this imports and executes Python code shipped in
        # the downloaded repo; only point `catboost_repo` at a trusted source.
        if catboost_local not in sys.path:
            # Avoid stacking duplicate entries when load_all() is re-run.
            sys.path.insert(0, catboost_local)
        from modeling_catboost import EssayCatBoostModel
        self.cat_model = EssayCatBoostModel.from_pretrained(catboost_local)

        print("loading textcnn")
        self.textcnn_model, tokenizer_name = TextCNNRegressor.from_pretrained(self.config["textcnn_repo"])
        self.textcnn_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.textcnn_model.to(self.device).eval()
        print("all models loaded")

    def _predict_transformer(self, model, tokenizer, texts, max_len=512, batch_size=8, use_sliding_window=False):
        """Return one prediction per text as a NumPy array.

        Texts are tokenized and scored one at a time; ``batch_size`` only
        controls how the list is chunked. With ``use_sliding_window`` a text
        longer than ``max_len - 2`` tokens is split into at most four
        half-overlapping windows whose predictions are averaged.
        """
        # forward()'s signature does not change during the call, so inspect
        # it once. Tokenizer outputs it does not accept are dropped.
        accepted = set(inspect.signature(model.forward).parameters.keys())

        @torch.no_grad()
        def _run_batch(batch_inputs):
            inputs = {k: v.to(self.device) for k, v in batch_inputs.items() if k in accepted}
            out = model(**inputs)
            return out.logits.squeeze(-1).cpu().numpy()

        all_preds = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_logits = []
            for text in batch_texts:
                if use_sliding_window:
                    token_ids = tokenizer.encode(str(text), add_special_tokens=False)
                    # Leave room for the two special tokens added on re-encoding.
                    effective_len = max_len - 2
                    stride = effective_len // 2
                    max_windows = 4
                    if len(token_ids) <= effective_len:
                        windows = [str(text)]
                    else:
                        windows = []
                        start = 0
                        while start < len(token_ids) and len(windows) < max_windows:
                            chunk = token_ids[start:start + effective_len]
                            windows.append(tokenizer.decode(chunk, skip_special_tokens=True))
                            start += stride
                    inputs = tokenizer(windows, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
                    logit = _run_batch(inputs).mean()
                else:
                    inputs = tokenizer([str(text)], padding=True, truncation=True, max_length=max_len, return_tensors="pt")
                    logit = _run_batch(inputs).item()
                batch_logits.append(logit)
            all_preds.extend(batch_logits)
        return np.array(all_preds)

    def _predict_catboost(self, df, modernbert_preds=None):
        """Return CatBoost predictions for ``df``.

        ``modernbert_preds`` may carry already computed ModernBERT
        predictions for the same rows; when omitted they are computed here.
        """
        feats = build_features(df)
        if modernbert_preds is None:
            texts = df["full_text"].tolist()
            modernbert_preds = self._predict_transformer(self.modernbert_model, self.modernbert_tokenizer, texts, max_len=1024, use_sliding_window=False)
        feats['modernbert_pred'] = modernbert_preds
        # The ridge feature is always supplied as a constant zero here.
        feats['ridge_pred'] = 0.0
        return self.cat_model.predict(feats)

    def get_all_predictions(self, df):
        """Return a dict mapping each key in ``MODEL_KEYS`` to its predictions.

        Raises:
            RuntimeError: if any sub-model has not been loaded yet.
        """
        not_loaded = [
            name for name in ("electra_model", "modernbert_model", "textcnn_model", "cat_model")
            if getattr(self, name) is None
        ]
        if not_loaded:
            raise RuntimeError("models not loaded (" + ", ".join(not_loaded) + "); call load_all() first")

        texts = df["full_text"].tolist()
        preds = {}
        print("electra")
        preds["electra"] = self._predict_transformer(self.electra_model, self.electra_tokenizer, texts, max_len=512, use_sliding_window=True)
        print("modernbert")
        modernbert_preds = self._predict_transformer(self.modernbert_model, self.modernbert_tokenizer, texts, max_len=1024, use_sliding_window=False)
        preds["modernbert"] = modernbert_preds
        print("catboost")
        # Reuse the ModernBERT predictions instead of running that model twice.
        preds["catboost"] = self._predict_catboost(df, modernbert_preds=modernbert_preds)
        print("textcnn")
        preds["textcnn"] = self._predict_transformer(self.textcnn_model, self.textcnn_tokenizer, texts, max_len=512, use_sliding_window=False)
        return preds

    def predict(self, df):
        """Return the weighted, clipped ensemble score for every row of ``df``.

        Only models with a strictly positive weight contribute. The result is
        always an array with one entry per row, clipped to
        ``[score_min, score_max]``.

        Raises:
            ValueError: if ``df`` is not a pandas DataFrame.
        """
        if not isinstance(df, pd.DataFrame):
            raise ValueError("input must be pandas DataFrame")
        print("getting predictions")
        preds = self.get_all_predictions(df)
        w = self.weights
        terms = [w[k] * np.asarray(preds[k]) for k in self.MODEL_KEYS if w.get(k, 0) > 0]
        # With no positive weight, sum() would collapse to the scalar 0 and
        # the caller would get a single number instead of one score per row.
        final = sum(terms) if terms else np.zeros(len(df))
        return np.clip(final, self.score_min, self.score_max)
|
|