# predict.py — prediction generation for the Talmud language classifier
# (Hugging Face commit be1bc6c, "Update predict.py", by shelfgot)
"""
Prediction generation module for Talmud language classifier
Generates predictions for all dafim using a trained model
"""
import torch
import requests
import os
import re
import warnings
from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN
# Preprocessing regex to match Vercel's preprocessing exactly
# Vercel uses: /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g
# NOTE(review): this compiled pattern is not referenced by any function in this
# module — preprocess_text() below re-implements the same rules character by
# character because it also needs position mappings, which re.sub cannot
# provide. Kept as the canonical reference for what must be removed; confirm
# no external caller imports it before deleting.
PREPROCESSING_REGEX = re.compile(r'[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>')
def preprocess_text(text: str) -> tuple[str, dict, dict]:
    """
    Preprocess text by removing nikud, punctuation, and HTML tags.

    Mirrors Vercel's preprocessing regex /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g
    exactly, while additionally tracking character-position mappings between the
    original and preprocessed text (which a plain re.sub cannot provide).

    Args:
        text: Original text; may contain nikud marks, punctuation, and HTML tags.

    Returns:
        (preprocessed_text, prep_to_orig, orig_to_prep) where:
        - prep_to_orig maps preprocessed position -> original position
        - orig_to_prep maps original position -> preprocessed position
          (or -1 if the original character was removed)
    """
    # Punctuation removed by the Vercel regex; nikud is handled by a range check.
    removed_punct = frozenset(',-?!:.״')
    # Accumulate characters in a list and join once at the end: repeated
    # `str +=` in a loop is quadratic in the worst case.
    out_chars = []
    prep_to_orig = {}  # Maps preprocessed_pos -> original_pos
    orig_to_prep = {}  # Maps original_pos -> preprocessed_pos (or -1 if removed)
    preprocessed_pos = 0
    i = 0
    n = len(text)
    # Process character by character, handling HTML tags as whole units.
    while i < n:
        if text[i] == '<':
            # HTML tags are removed as units. An unterminated '<' (no closing
            # '>') falls through and is kept, matching the regex, which also
            # would not match an unclosed tag.
            tag_end = text.find('>', i)
            if tag_end != -1:
                # Mark every character inside the tag as removed.
                for orig_pos in range(i, tag_end + 1):
                    orig_to_prep[orig_pos] = -1
                i = tag_end + 1
                continue
        char = text[i]
        # Remove if in the nikud range U+0591..U+05C7 or in the punctuation set.
        if '\u0591' <= char <= '\u05C7' or char in removed_punct:
            orig_to_prep[i] = -1  # Mark as removed
        else:
            prep_to_orig[preprocessed_pos] = i
            orig_to_prep[i] = preprocessed_pos
            out_chars.append(char)
            preprocessed_pos += 1
        i += 1
    return ''.join(out_chars), prep_to_orig, orig_to_prep
def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
    """
    Fetch all daf texts from the Vercel API endpoint.

    Args:
        vercel_base_url: Base URL of the Vercel app
        auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN)

    Returns:
        List of daf objects with id and text_content.
    """
    url = f"{vercel_base_url}/api/dafim-texts"
    print(f"Fetching daf texts from {url}...")
    # Authentication token goes in a custom header.
    request_headers = {
        'x-auth-token': auth_token,
        'Content-Type': 'application/json'
    }
    try:
        response = requests.get(url, headers=request_headers, timeout=60)
        response.raise_for_status()
        payload = response.json()
        print(f"Fetched {payload.get('count', 0)} dafim")
        return payload.get('dafim', [])
    except Exception as e:
        print(f"Error fetching daf texts: {e}")
        # requests exceptions carry the server response when one was received;
        # dump it for debugging before re-raising.
        err_response = getattr(e, 'response', None)
        if err_response is not None:
            print(f"Response status: {err_response.status_code}")
            print(f"Response text: {err_response.text}")
        raise
def text_to_sequence(text: str, word_to_idx: dict) -> list:
    """Convert whitespace-split text into a list of word indices, mapping
    out-of-vocabulary words to the '<UNK>' index."""
    # Fail fast if the vocabulary is missing its special tokens.
    for required in ('<UNK>', '<PAD>'):
        if required not in word_to_idx:
            raise ValueError(f"Vocabulary must contain '{required}' key")
    unk_idx = word_to_idx['<UNK>']
    indices = []
    for word in text.split():
        indices.append(word_to_idx.get(word, unk_idx))
    return indices
def generate_predictions_for_daf(
    model: torch.nn.Module,
    daf_text: str,
    word_to_idx: dict,
    label_encoder,
    max_len: int = MAX_LEN
) -> list:
    """
    Generate predictions for a single daf text (original text, not preprocessed).

    Strategy: sliding-window classification. The text is preprocessed (nikud,
    punctuation, and HTML removed) and split into words; overlapping windows of
    up to `max_len` words are classified by the model, and each window's word
    span is mapped back to character offsets in the ORIGINAL text via the
    position maps from preprocess_text(). Overlapping/adjacent ranges of the
    same predicted type are then merged.

    Args:
        model: Trained classifier; called as model(tensor). NOTE(review):
            assumed to return logits of shape (batch, num_classes) — the code
            takes argmax over dim 1; confirm against TalmudClassifierLSTM.
        daf_text: Original (unpreprocessed) daf text.
        word_to_idx: Vocabulary mapping; must contain '<UNK>' and '<PAD>'.
        label_encoder: NOTE(review): accepted but never used in this body —
            'type' in the returned ranges is the raw predicted class index.
        max_len: Window size in words (also the padded sequence length).

    Returns:
        List of ranges [{'start': int, 'end': int, 'type': int}, ...] with
        positions relative to the original text.
    """
    model.eval()
    # Preprocess the text and get character mappings
    preprocessed_text, prep_to_orig, orig_to_prep = preprocess_text(daf_text)
    # Split into words and track character positions accurately
    words = preprocessed_text.split()
    if len(words) == 0:
        return []
    # Build word boundaries in preprocessed text by tracking positions as we iterate
    # This is more reliable than using find() which could match wrong occurrences
    word_boundaries = []
    char_pos = 0
    word_idx = 0
    # Iterate through preprocessed text to find word boundaries
    while char_pos < len(preprocessed_text) and word_idx < len(words):
        # Skip leading spaces
        while char_pos < len(preprocessed_text) and preprocessed_text[char_pos] == ' ':
            char_pos += 1
        if char_pos >= len(preprocessed_text):
            break
        # Find the current word
        word = words[word_idx]
        word_start = char_pos
        # Check if the word starts at this position
        if preprocessed_text[char_pos:char_pos + len(word)] == word:
            word_end = char_pos + len(word)
            word_boundaries.append((word_start, word_end))
            char_pos = word_end
            word_idx += 1
        else:
            # Word doesn't match - this shouldn't happen, but handle gracefully
            # (split() also splits on non-space whitespace such as \n and \t,
            # while the scan above only skips ' ' — a possible cause of this path)
            # Try to find the word starting from current position
            found_pos = preprocessed_text.find(word, char_pos)
            if found_pos != -1:
                word_boundaries.append((found_pos, found_pos + len(word)))
                char_pos = found_pos + len(word)
                word_idx += 1
            else:
                # Couldn't find word - this indicates a mismatch between words and preprocessed_text
                # This can happen if preprocessing changed the text in an unexpected way
                # Log a warning and use a fallback: estimate position based on character count
                warnings.warn(f"Word '{word}' at index {word_idx} not found in preprocessed text. Using estimated position.")
                # Estimate position: assume words are separated by single spaces
                estimated_start = char_pos
                estimated_end = estimated_start + len(word)
                word_boundaries.append((estimated_start, min(estimated_end, len(preprocessed_text))))
                char_pos = estimated_end
                word_idx += 1
    # Validate that we found boundaries for all words
    if len(word_boundaries) < len(words):
        warnings.warn(f"Only found boundaries for {len(word_boundaries)} out of {len(words)} words. Some predictions may be inaccurate.")
    # Use sliding window approach
    window_size = max_len
    stride = window_size // 2  # 50% overlap
    predictions = []  # NOTE(review): dead variable — never appended to or read
    ranges = []
    with torch.no_grad():
        for i in range(0, len(words), stride):
            # Get window of words
            window_words = words[i:i + window_size]
            if len(window_words) == 0:
                break
            # Convert to sequence
            seq = text_to_sequence(' '.join(window_words), word_to_idx)
            # Pad or truncate to max_len
            if len(seq) > max_len:
                seq = seq[:max_len]
            else:
                seq = seq + [word_to_idx['<PAD>']] * (max_len - len(seq))
            # Convert to tensor and add batch dimension
            seq_tensor = torch.tensor([seq], dtype=torch.long)
            # Get prediction (argmax over the class dimension)
            output = model(seq_tensor)
            _, predicted = torch.max(output.data, 1)
            predicted_label_idx = predicted.item()
            # Calculate character positions in preprocessed text using word boundaries
            # Ensure we don't go out of bounds
            if i >= len(word_boundaries):
                continue
            last_word_idx = min(i + len(window_words) - 1, len(word_boundaries) - 1)
            if last_word_idx < i:
                continue
            # Start position is the start of the first word in the window
            window_start_prep = word_boundaries[i][0]
            # End position is the end of the last word in the window
            window_end_prep = word_boundaries[last_word_idx][1]
            # Only add if we have a valid range
            if window_end_prep > window_start_prep:
                # Map preprocessed text positions to original text positions
                # Find the original start position
                original_start = prep_to_orig.get(window_start_prep)
                if original_start is None:
                    # Find the closest mapped position before or at window_start_prep
                    # (linear scan over sorted keys; fine for daf-sized texts)
                    for prep_pos in sorted(prep_to_orig.keys(), reverse=True):
                        if prep_pos <= window_start_prep:
                            original_start = prep_to_orig[prep_pos]
                            break
                if original_start is None:
                    continue  # Skip if we can't map start position
                # Find the original end position
                # window_end_prep points to the character after the last character in the window
                # We need to map this to the original text
                window_end_prep_clamped = min(window_end_prep, len(preprocessed_text))
                # Find the original position corresponding to the end of the window
                # If window_end_prep_clamped is at the end of preprocessed text, use end of original text
                if window_end_prep_clamped >= len(preprocessed_text):
                    original_end = len(daf_text)
                else:
                    # Find the original position for the character at window_end_prep_clamped
                    # (the character right after the window ends)
                    end_char_orig = prep_to_orig.get(window_end_prep_clamped)
                    if end_char_orig is not None:
                        original_end = end_char_orig
                    else:
                        # Character at window_end_prep_clamped was removed, find the next non-removed character
                        # Look for the next preprocessed position >= window_end_prep_clamped
                        next_prep_pos = None
                        for prep_pos in sorted(prep_to_orig.keys()):
                            if prep_pos >= window_end_prep_clamped:
                                next_prep_pos = prep_pos
                                break
                        if next_prep_pos is not None:
                            original_end = prep_to_orig[next_prep_pos]
                        else:
                            # No more characters in preprocessed text, use end of original text
                            original_end = len(daf_text)
                # Ensure end is after start and within bounds
                if original_end <= original_start:
                    # Fallback: ensure at least one character
                    original_end = min(original_start + 1, len(daf_text))
                original_end = min(original_end, len(daf_text))
                ranges.append({
                    'start': original_start,
                    'end': original_end,
                    'type': int(predicted_label_idx)
                })
    # Merge overlapping ranges of the same type
    if len(ranges) == 0:
        return []
    # Sort by start position
    ranges.sort(key=lambda x: x['start'])
    # Merge consecutive ranges of same type
    merged_ranges = []
    current_range = ranges[0].copy()
    for next_range in ranges[1:]:
        # If same type and overlapping or adjacent, merge
        if (next_range['type'] == current_range['type'] and
            next_range['start'] <= current_range['end']):
            current_range['end'] = max(current_range['end'], next_range['end'])
        else:
            merged_ranges.append(current_range)
            current_range = next_range.copy()
    merged_ranges.append(current_range)
    return merged_ranges
def generate_all_predictions(
    model: torch.nn.Module,
    word_to_idx: dict,
    label_encoder,
    vercel_base_url: str,
    auth_token: str
) -> list:
    """
    DEPRECATED: no longer part of the training flow; kept for reference only.

    Generate predictions for every daf fetched from the Vercel API.

    NOTE: the API returns preprocessed text, while generate_predictions_for_daf
    now expects ORIGINAL text — character positions will be mapped incorrectly.
    Update the text source before ever calling this again.

    Args:
        model: Trained model
        word_to_idx: Word to index mapping
        label_encoder: Label encoder
        vercel_base_url: Base URL of the Vercel app
        auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN)

    Returns:
        List of prediction objects: [{'daf_id': str, 'ranges': [...]}, ...]
    """
    print("WARNING: generate_all_predictions is deprecated and may not work correctly.")
    print("Fetching daf texts from Vercel...")
    dafim = fetch_daf_texts(vercel_base_url, auth_token)
    if not dafim:
        print("No dafim found")
        return []
    predictions = []
    total = len(dafim)
    print(f"Generating predictions for {total} dafim...")
    for count, daf in enumerate(dafim, start=1):
        # Progress heartbeat every 100 dafim.
        if count % 100 == 0:
            print(f"Processed {count}/{total} dafim...")
        try:
            daf_id = daf['id']
            # See the deprecation note: this text is preprocessed, but the
            # predictor expects original text.
            ranges = generate_predictions_for_daf(
                model, daf['text_content'], word_to_idx, label_encoder
            )
            predictions.append({
                'daf_id': daf_id,
                'ranges': ranges
            })
        except Exception as e:
            # Best-effort batch job: report the failure and move on.
            print(f"Error generating predictions for daf {daf.get('id', 'unknown')}: {e}")
            continue
    print(f"Generated predictions for {len(predictions)} dafim")
    return predictions