Spaces:
Sleeping
Sleeping
File size: 14,951 Bytes
1f71502 8176e08 be1bc6c 1f71502 8176e08 1f71502 8176e08 1f71502 307b502 1f71502 307b502 1f71502 8176e08 307b502 1f71502 6441bee 8176e08 1f71502 be1bc6c 1f71502 8176e08 1f71502 8176e08 1f71502 8176e08 1f71502 8176e08 1f71502 8176e08 be1bc6c 8176e08 1f71502 8176e08 1f71502 8176e08 1f71502 8176e08 1f71502 8176e08 1f71502 8176e08 1f71502 307b502 1f71502 8176e08 1f71502 307b502 8176e08 307b502 1f71502 8176e08 1f71502 307b502 1f71502 8176e08 1f71502 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 |
"""
Prediction generation module for Talmud language classifier
Generates predictions for all dafim using a trained model
"""
import torch
import requests
import os
import re
import warnings
from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN
# Preprocessing regex to match Vercel's preprocessing exactly
# Vercel uses: /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g
# NOTE(review): this compiled pattern is not referenced anywhere in this
# module's visible code — kept for documentation/reference purposes.
PREPROCESSING_REGEX = re.compile(r'[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>')
def preprocess_text(text: str) -> tuple[str, dict, dict]:
    r"""
    Preprocess text by removing nikud, punctuation, and HTML tags.
    Matches Vercel's preprocessing regex exactly:
    /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g

    Args:
        text: Original (unpreprocessed) daf text.

    Returns:
        (preprocessed_text, prep_to_orig, orig_to_prep) where:
        - prep_to_orig maps preprocessed position -> original position
        - orig_to_prep maps original position -> preprocessed position
          (or -1 if the original character was removed)
    """
    # Collect kept characters in a list and join once at the end;
    # repeated `str +=` in a loop is quadratic in the worst case.
    kept_chars = []
    prep_to_orig = {}  # Maps preprocessed_pos -> original_pos
    orig_to_prep = {}  # Maps original_pos -> preprocessed_pos (or -1 if removed)
    preprocessed_pos = 0
    i = 0
    # Process text character by character, handling HTML tags as units
    while i < len(text):
        # Check for HTML tags (they are removed as units).
        # The JS pattern <[^>]+> requires at least ONE character between
        # '<' and '>', so a literal "<>" is NOT a tag and must be kept;
        # hence `tag_end > i + 1` rather than `tag_end != -1`.
        if text[i] == '<':
            tag_end = text.find('>', i)
            if tag_end > i + 1:
                # Mark all characters in the tag as removed
                for orig_pos in range(i, tag_end + 1):
                    orig_to_prep[orig_pos] = -1
                i = tag_end + 1
                continue
        char = text[i]
        char_code = ord(char)
        # Check if character should be removed:
        # 1. Nikud/cantillation range: \u0591-\u05C7 (0x0591 to 0x05C7)
        # 2. Punctuation: , - ? ! : . ״
        should_remove = (
            (0x0591 <= char_code <= 0x05C7) or
            char in ',-?!:.״'
        )
        if should_remove:
            orig_to_prep[i] = -1  # Mark as removed
        else:
            prep_to_orig[preprocessed_pos] = i
            orig_to_prep[i] = preprocessed_pos
            kept_chars.append(char)
            preprocessed_pos += 1
        i += 1
    return ''.join(kept_chars), prep_to_orig, orig_to_prep
def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
    """
    Fetch all daf texts from the Vercel API endpoint.

    Args:
        vercel_base_url: Base URL of the Vercel app
        auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN)

    Returns:
        List of daf objects with id and text_content (empty list if none).

    Raises:
        Re-raises any request/HTTP error after logging its details.
    """
    endpoint = f"{vercel_base_url}/api/dafim-texts"
    print(f"Fetching daf texts from {endpoint}...")
    # Authentication token travels in the x-auth-token header.
    request_headers = {
        'x-auth-token': auth_token,
        'Content-Type': 'application/json'
    }
    try:
        resp = requests.get(endpoint, headers=request_headers, timeout=60)
        resp.raise_for_status()
        payload = resp.json()
    except Exception as exc:
        print(f"Error fetching daf texts: {exc}")
        # requests.HTTPError carries the failed response; surface its details.
        failed_response = getattr(exc, 'response', None)
        if failed_response is not None:
            print(f"Response status: {failed_response.status_code}")
            print(f"Response text: {failed_response.text}")
        raise
    print(f"Fetched {payload.get('count', 0)} dafim")
    return payload.get('dafim', [])
def text_to_sequence(text: str, word_to_idx: dict) -> list:
    """Convert whitespace-split text into a list of vocabulary indices.

    Words absent from the vocabulary map to the '<UNK>' index.

    Raises:
        ValueError: if the vocabulary lacks the '<UNK>' or '<PAD>' entry.
    """
    # Validate the required special tokens up front ('<UNK>' checked first,
    # matching the original validation order).
    for required_key in ('<UNK>', '<PAD>'):
        if required_key not in word_to_idx:
            raise ValueError(f"Vocabulary must contain '{required_key}' key")
    unk_idx = word_to_idx['<UNK>']
    return [word_to_idx.get(token, unk_idx) for token in text.split()]
def generate_predictions_for_daf(
    model: torch.nn.Module,
    daf_text: str,
    word_to_idx: dict,
    label_encoder,
    max_len: int = MAX_LEN
) -> list:
    """
    Generate predictions for a single daf text (original text, not preprocessed).
    Returns list of ranges: [{'start': int, 'end': int, 'type': int}, ...]
    Positions are relative to the original text.
    Strategy: Sliding window approach - predict on overlapping windows of text

    Args:
        model: Trained classifier; called on padded index tensors of length max_len.
        daf_text: Original daf text (nikud/punctuation/HTML intact).
        word_to_idx: Vocabulary mapping; must contain '<UNK>' and '<PAD>'.
        label_encoder: NOTE(review): accepted but never referenced in this body --
            the 'type' field is the raw argmax index. Confirm whether callers
            expect it to be decoded via the label encoder.
        max_len: Window size in words; also the model input length.
    """
    model.eval()
    # Preprocess the text and get character mappings
    preprocessed_text, prep_to_orig, orig_to_prep = preprocess_text(daf_text)
    # Split into words and track character positions accurately
    words = preprocessed_text.split()
    if len(words) == 0:
        return []
    # Build word boundaries in preprocessed text by tracking positions as we iterate
    # This is more reliable than using find() which could match wrong occurrences
    word_boundaries = []
    char_pos = 0
    word_idx = 0
    # Iterate through preprocessed text to find word boundaries
    while char_pos < len(preprocessed_text) and word_idx < len(words):
        # Skip leading spaces
        # NOTE(review): only ' ' is skipped here, while split() splits on any
        # whitespace; other whitespace chars fall through to the find() fallback.
        while char_pos < len(preprocessed_text) and preprocessed_text[char_pos] == ' ':
            char_pos += 1
        if char_pos >= len(preprocessed_text):
            break
        # Find the current word
        word = words[word_idx]
        word_start = char_pos
        # Check if the word starts at this position
        if preprocessed_text[char_pos:char_pos + len(word)] == word:
            word_end = char_pos + len(word)
            word_boundaries.append((word_start, word_end))
            char_pos = word_end
            word_idx += 1
        else:
            # Word doesn't match - this shouldn't happen, but handle gracefully
            # Try to find the word starting from current position
            found_pos = preprocessed_text.find(word, char_pos)
            if found_pos != -1:
                word_boundaries.append((found_pos, found_pos + len(word)))
                char_pos = found_pos + len(word)
                word_idx += 1
            else:
                # Couldn't find word - this indicates a mismatch between words and preprocessed_text
                # This can happen if preprocessing changed the text in an unexpected way
                # Log a warning and use a fallback: estimate position based on character count
                warnings.warn(f"Word '{word}' at index {word_idx} not found in preprocessed text. Using estimated position.")
                # Estimate position: assume words are separated by single spaces
                estimated_start = char_pos
                estimated_end = estimated_start + len(word)
                word_boundaries.append((estimated_start, min(estimated_end, len(preprocessed_text))))
                char_pos = estimated_end
                word_idx += 1
    # Validate that we found boundaries for all words
    if len(word_boundaries) < len(words):
        warnings.warn(f"Only found boundaries for {len(word_boundaries)} out of {len(words)} words. Some predictions may be inaccurate.")
    # Use sliding window approach
    window_size = max_len
    stride = window_size // 2  # 50% overlap
    # NOTE(review): 'predictions' is never used below; results accumulate in 'ranges'.
    predictions = []
    ranges = []
    with torch.no_grad():
        for i in range(0, len(words), stride):
            # Get window of words
            window_words = words[i:i + window_size]
            if len(window_words) == 0:
                break
            # Convert to sequence
            seq = text_to_sequence(' '.join(window_words), word_to_idx)
            # Pad or truncate to max_len
            if len(seq) > max_len:
                seq = seq[:max_len]
            else:
                seq = seq + [word_to_idx['<PAD>']] * (max_len - len(seq))
            # Convert to tensor and add batch dimension
            seq_tensor = torch.tensor([seq], dtype=torch.long)
            # Get prediction (single label index per window via argmax over dim 1)
            output = model(seq_tensor)
            _, predicted = torch.max(output.data, 1)
            predicted_label_idx = predicted.item()
            # Calculate character positions in preprocessed text using word boundaries
            # Ensure we don't go out of bounds
            if i >= len(word_boundaries):
                continue
            last_word_idx = min(i + len(window_words) - 1, len(word_boundaries) - 1)
            if last_word_idx < i:
                continue
            # Start position is the start of the first word in the window
            window_start_prep = word_boundaries[i][0]
            # End position is the end of the last word in the window
            window_end_prep = word_boundaries[last_word_idx][1]
            # Only add if we have a valid range
            if window_end_prep > window_start_prep:
                # Map preprocessed text positions to original text positions
                # Find the original start position
                original_start = prep_to_orig.get(window_start_prep)
                if original_start is None:
                    # Find the closest mapped position before or at window_start_prep
                    # NOTE(review): this sorted() scan is O(n log n) per miss;
                    # acceptable only because the miss path should be rare.
                    for prep_pos in sorted(prep_to_orig.keys(), reverse=True):
                        if prep_pos <= window_start_prep:
                            original_start = prep_to_orig[prep_pos]
                            break
                if original_start is None:
                    continue  # Skip if we can't map start position
                # Find the original end position
                # window_end_prep points to the character after the last character in the window
                # We need to map this to the original text
                window_end_prep_clamped = min(window_end_prep, len(preprocessed_text))
                # Find the original position corresponding to the end of the window
                # If window_end_prep_clamped is at the end of preprocessed text, use end of original text
                if window_end_prep_clamped >= len(preprocessed_text):
                    original_end = len(daf_text)
                else:
                    # Find the original position for the character at window_end_prep_clamped
                    # (the character right after the window ends)
                    end_char_orig = prep_to_orig.get(window_end_prep_clamped)
                    if end_char_orig is not None:
                        original_end = end_char_orig
                    else:
                        # Character at window_end_prep_clamped was removed, find the next non-removed character
                        # Look for the next preprocessed position >= window_end_prep_clamped
                        next_prep_pos = None
                        for prep_pos in sorted(prep_to_orig.keys()):
                            if prep_pos >= window_end_prep_clamped:
                                next_prep_pos = prep_pos
                                break
                        if next_prep_pos is not None:
                            original_end = prep_to_orig[next_prep_pos]
                        else:
                            # No more characters in preprocessed text, use end of original text
                            original_end = len(daf_text)
                # Ensure end is after start and within bounds
                if original_end <= original_start:
                    # Fallback: ensure at least one character
                    original_end = min(original_start + 1, len(daf_text))
                original_end = min(original_end, len(daf_text))
                ranges.append({
                    'start': original_start,
                    'end': original_end,
                    'type': int(predicted_label_idx)
                })
    # Merge overlapping ranges of the same type
    # (50% window overlap guarantees adjacent windows produce overlapping ranges)
    if len(ranges) == 0:
        return []
    # Sort by start position
    ranges.sort(key=lambda x: x['start'])
    # Merge consecutive ranges of same type
    merged_ranges = []
    current_range = ranges[0].copy()
    for next_range in ranges[1:]:
        # If same type and overlapping or adjacent, merge
        if (next_range['type'] == current_range['type'] and
                next_range['start'] <= current_range['end']):
            current_range['end'] = max(current_range['end'], next_range['end'])
        else:
            merged_ranges.append(current_range)
            current_range = next_range.copy()
    merged_ranges.append(current_range)
    return merged_ranges
def generate_all_predictions(
    model: torch.nn.Module,
    word_to_idx: dict,
    label_encoder,
    vercel_base_url: str,
    auth_token: str
) -> list:
    """
    DEPRECATED: no longer part of the training flow; retained for reference only.

    Generates predictions for every daf fetched from the Vercel API and returns
    a list of prediction objects: [{'daf_id': str, 'ranges': [...]}, ...].

    CAUTION: the API serves preprocessed text, while generate_predictions_for_daf
    now expects original text — character positions produced here would be wrong.
    Update this function before ever calling it again.

    Args:
        model: Trained model
        word_to_idx: Word to index mapping
        label_encoder: Label encoder
        vercel_base_url: Base URL of the Vercel app
        auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN)
    """
    print("WARNING: generate_all_predictions is deprecated and may not work correctly.")
    print("Fetching daf texts from Vercel...")
    dafim = fetch_daf_texts(vercel_base_url, auth_token)
    if not dafim:
        print("No dafim found")
        return []
    predictions = []
    total = len(dafim)
    print(f"Generating predictions for {total} dafim...")
    for position, daf in enumerate(dafim, start=1):
        # Progress heartbeat every 100 dafim.
        if position % 100 == 0:
            print(f"Processed {position}/{total} dafim...")
        try:
            daf_id = daf['id']
            # text_content is preprocessed text from the API, but
            # generate_predictions_for_daf expects original text — see docstring.
            text_content = daf['text_content']
            daf_ranges = generate_predictions_for_daf(
                model, text_content, word_to_idx, label_encoder
            )
            predictions.append({
                'daf_id': daf_id,
                'ranges': daf_ranges
            })
        except Exception as err:
            # Log and move on to the next daf rather than aborting the batch.
            print(f"Error generating predictions for daf {daf.get('id', 'unknown')}: {err}")
            continue
    print(f"Generated predictions for {len(predictions)} dafim")
    return predictions
|