""" Prediction generation module for Talmud language classifier Generates predictions for all dafim using a trained model """ import torch import requests import os import re import warnings from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN # Preprocessing regex to match Vercel's preprocessing exactly # Vercel uses: /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g PREPROCESSING_REGEX = re.compile(r'[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>') def preprocess_text(text: str) -> tuple[str, dict, dict]: """ Preprocess text by removing nikud, punctuation, and HTML tags. Matches Vercel's preprocessing exactly. Returns (preprocessed_text, prep_to_orig, orig_to_prep) where: - prep_to_orig maps preprocessed position -> original position - orig_to_prep maps original position -> preprocessed position (or -1 if removed) """ preprocessed = '' prep_to_orig = {} # Maps preprocessed_pos -> original_pos orig_to_prep = {} # Maps original_pos -> preprocessed_pos (or -1 if removed) preprocessed_pos = 0 i = 0 # Process text character by character, handling HTML tags as units while i < len(text): # Check for HTML tags (they are removed as units) if text[i] == '<': # Find the end of the HTML tag tag_end = text.find('>', i) if tag_end != -1: # Mark all characters in the tag as removed for orig_pos in range(i, tag_end + 1): orig_to_prep[orig_pos] = -1 i = tag_end + 1 continue char = text[i] char_code = ord(char) # Check if character should be removed: # 1. Nikud range: \u0591-\u05C7 (0x0591 to 0x05C7) # 2. Punctuation: , - ? ! : . ״ should_remove = ( (0x0591 <= char_code <= 0x05C7) or char in [',', '-', '?', '!', ':', '.', '״'] ) if should_remove: orig_to_prep[i] = -1 # Mark as removed else: prep_to_orig[preprocessed_pos] = i orig_to_prep[i] = preprocessed_pos preprocessed += char preprocessed_pos += 1 i += 1 return preprocessed, prep_to_orig, orig_to_prep def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list: """ Fetch all daf texts from Vercel API endpoint. Returns list of daf objects with id and text_content. Args: vercel_base_url: Base URL of the Vercel app auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN) """ url = f"{vercel_base_url}/api/dafim-texts" print(f"Fetching daf texts from {url}...") try: # Include authentication token in header headers = { 'x-auth-token': auth_token, 'Content-Type': 'application/json' } response = requests.get(url, headers=headers, timeout=60) response.raise_for_status() data = response.json() print(f"Fetched {data.get('count', 0)} dafim") return data.get('dafim', []) except Exception as e: print(f"Error fetching daf texts: {e}") if hasattr(e, 'response') and e.response is not None: print(f"Response status: {e.response.status_code}") print(f"Response text: {e.response.text}") raise def text_to_sequence(text: str, word_to_idx: dict) -> list: """Convert text to sequence of word indices""" # Validate that required keys exist if '' not in word_to_idx: raise ValueError("Vocabulary must contain '' key") if '' not in word_to_idx: raise ValueError("Vocabulary must contain '' key") words = text.split() return [word_to_idx.get(word, word_to_idx['']) for word in words] def generate_predictions_for_daf( model: torch.nn.Module, daf_text: str, word_to_idx: dict, label_encoder, max_len: int = MAX_LEN ) -> list: """ Generate predictions for a single daf text (original text, not preprocessed). Returns list of ranges: [{'start': int, 'end': int, 'type': int}, ...] Positions are relative to the original text. Strategy: Sliding window approach - predict on overlapping windows of text """ model.eval() # Preprocess the text and get character mappings preprocessed_text, prep_to_orig, orig_to_prep = preprocess_text(daf_text) # Split into words and track character positions accurately words = preprocessed_text.split() if len(words) == 0: return [] # Build word boundaries in preprocessed text by tracking positions as we iterate # This is more reliable than using find() which could match wrong occurrences word_boundaries = [] char_pos = 0 word_idx = 0 # Iterate through preprocessed text to find word boundaries while char_pos < len(preprocessed_text) and word_idx < len(words): # Skip leading spaces while char_pos < len(preprocessed_text) and preprocessed_text[char_pos] == ' ': char_pos += 1 if char_pos >= len(preprocessed_text): break # Find the current word word = words[word_idx] word_start = char_pos # Check if the word starts at this position if preprocessed_text[char_pos:char_pos + len(word)] == word: word_end = char_pos + len(word) word_boundaries.append((word_start, word_end)) char_pos = word_end word_idx += 1 else: # Word doesn't match - this shouldn't happen, but handle gracefully # Try to find the word starting from current position found_pos = preprocessed_text.find(word, char_pos) if found_pos != -1: word_boundaries.append((found_pos, found_pos + len(word))) char_pos = found_pos + len(word) word_idx += 1 else: # Couldn't find word - this indicates a mismatch between words and preprocessed_text # This can happen if preprocessing changed the text in an unexpected way # Log a warning and use a fallback: estimate position based on character count warnings.warn(f"Word '{word}' at index {word_idx} not found in preprocessed text. Using estimated position.") # Estimate position: assume words are separated by single spaces estimated_start = char_pos estimated_end = estimated_start + len(word) word_boundaries.append((estimated_start, min(estimated_end, len(preprocessed_text)))) char_pos = estimated_end word_idx += 1 # Validate that we found boundaries for all words if len(word_boundaries) < len(words): warnings.warn(f"Only found boundaries for {len(word_boundaries)} out of {len(words)} words. Some predictions may be inaccurate.") # Use sliding window approach window_size = max_len stride = window_size // 2 # 50% overlap predictions = [] ranges = [] with torch.no_grad(): for i in range(0, len(words), stride): # Get window of words window_words = words[i:i + window_size] if len(window_words) == 0: break # Convert to sequence seq = text_to_sequence(' '.join(window_words), word_to_idx) # Pad or truncate to max_len if len(seq) > max_len: seq = seq[:max_len] else: seq = seq + [word_to_idx['']] * (max_len - len(seq)) # Convert to tensor and add batch dimension seq_tensor = torch.tensor([seq], dtype=torch.long) # Get prediction output = model(seq_tensor) _, predicted = torch.max(output.data, 1) predicted_label_idx = predicted.item() # Calculate character positions in preprocessed text using word boundaries # Ensure we don't go out of bounds if i >= len(word_boundaries): continue last_word_idx = min(i + len(window_words) - 1, len(word_boundaries) - 1) if last_word_idx < i: continue # Start position is the start of the first word in the window window_start_prep = word_boundaries[i][0] # End position is the end of the last word in the window window_end_prep = word_boundaries[last_word_idx][1] # Only add if we have a valid range if window_end_prep > window_start_prep: # Map preprocessed text positions to original text positions # Find the original start position original_start = prep_to_orig.get(window_start_prep) if original_start is None: # Find the closest mapped position before or at window_start_prep for prep_pos in sorted(prep_to_orig.keys(), reverse=True): if prep_pos <= window_start_prep: original_start = prep_to_orig[prep_pos] break if original_start is None: continue # Skip if we can't map start position # Find the original end position # window_end_prep points to the character after the last character in the window # We need to map this to the original text window_end_prep_clamped = min(window_end_prep, len(preprocessed_text)) # Find the original position corresponding to the end of the window # If window_end_prep_clamped is at the end of preprocessed text, use end of original text if window_end_prep_clamped >= len(preprocessed_text): original_end = len(daf_text) else: # Find the original position for the character at window_end_prep_clamped # (the character right after the window ends) end_char_orig = prep_to_orig.get(window_end_prep_clamped) if end_char_orig is not None: original_end = end_char_orig else: # Character at window_end_prep_clamped was removed, find the next non-removed character # Look for the next preprocessed position >= window_end_prep_clamped next_prep_pos = None for prep_pos in sorted(prep_to_orig.keys()): if prep_pos >= window_end_prep_clamped: next_prep_pos = prep_pos break if next_prep_pos is not None: original_end = prep_to_orig[next_prep_pos] else: # No more characters in preprocessed text, use end of original text original_end = len(daf_text) # Ensure end is after start and within bounds if original_end <= original_start: # Fallback: ensure at least one character original_end = min(original_start + 1, len(daf_text)) original_end = min(original_end, len(daf_text)) ranges.append({ 'start': original_start, 'end': original_end, 'type': int(predicted_label_idx) }) # Merge overlapping ranges of the same type if len(ranges) == 0: return [] # Sort by start position ranges.sort(key=lambda x: x['start']) # Merge consecutive ranges of same type merged_ranges = [] current_range = ranges[0].copy() for next_range in ranges[1:]: # If same type and overlapping or adjacent, merge if (next_range['type'] == current_range['type'] and next_range['start'] <= current_range['end']): current_range['end'] = max(current_range['end'], next_range['end']) else: merged_ranges.append(current_range) current_range = next_range.copy() merged_ranges.append(current_range) return merged_ranges def generate_all_predictions( model: torch.nn.Module, word_to_idx: dict, label_encoder, vercel_base_url: str, auth_token: str ) -> list: """ DEPRECATED: This function is no longer used in the training flow. It's kept for reference but should not be called. Generate predictions for all dafim. Returns list of prediction objects: [{'daf_id': str, 'ranges': [...]}, ...] NOTE: This function expects preprocessed text from the API, but generate_predictions_for_daf now expects original text. This function needs to be updated if it's ever used again. Args: model: Trained model word_to_idx: Word to index mapping label_encoder: Label encoder vercel_base_url: Base URL of the Vercel app auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN) """ print("WARNING: generate_all_predictions is deprecated and may not work correctly.") print("Fetching daf texts from Vercel...") dafim = fetch_daf_texts(vercel_base_url, auth_token) if len(dafim) == 0: print("No dafim found") return [] predictions = [] print(f"Generating predictions for {len(dafim)} dafim...") for idx, daf in enumerate(dafim): if (idx + 1) % 100 == 0: print(f"Processed {idx + 1}/{len(dafim)} dafim...") try: daf_id = daf['id'] # NOTE: The API returns preprocessed text, but generate_predictions_for_daf # now expects original text. This will cause incorrect character position mapping. # This function should fetch original text or be updated to handle preprocessed text. text_content = daf['text_content'] ranges = generate_predictions_for_daf( model, text_content, word_to_idx, label_encoder ) predictions.append({ 'daf_id': daf_id, 'ranges': ranges }) except Exception as e: print(f"Error generating predictions for daf {daf.get('id', 'unknown')}: {e}") # Continue with next daf continue print(f"Generated predictions for {len(predictions)} dafim") return predictions