Spaces:
Sleeping
Sleeping
| """ | |
| Prediction generation module for Talmud language classifier | |
| Generates predictions for all dafim using a trained model | |
| """ | |
| import torch | |
| import requests | |
| import os | |
| import re | |
| import warnings | |
| from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN | |
# Preprocessing regex to match Vercel's preprocessing exactly
# Vercel uses: /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g
# (Hebrew cantillation/nikud marks, a small punctuation set, and HTML tags.)
PREPROCESSING_REGEX = re.compile(r'[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>')


def preprocess_text(text: str) -> tuple[str, dict, dict]:
    """
    Preprocess text by removing nikud, punctuation, and HTML tags.

    The removal rules are driven directly by PREPROCESSING_REGEX so they
    cannot drift from Vercel's client-side regex. (A previous hand-rolled
    scanner also stripped empty "<>" tags, which Vercel's `<[^>]+>` does
    not match — that divergence is fixed here.)

    Args:
        text: Original daf text.

    Returns:
        (preprocessed_text, prep_to_orig, orig_to_prep) where:
        - prep_to_orig maps preprocessed position -> original position
        - orig_to_prep maps original position -> preprocessed position
          (or -1 if the original character was removed)
    """
    # Mark every character covered by a regex match as removed.
    removed = bytearray(len(text))
    for match in PREPROCESSING_REGEX.finditer(text):
        for orig_pos in range(match.start(), match.end()):
            removed[orig_pos] = 1

    kept_chars = []
    prep_to_orig = {}  # Maps preprocessed_pos -> original_pos
    orig_to_prep = {}  # Maps original_pos -> preprocessed_pos (or -1 if removed)
    prep_pos = 0
    for orig_pos, char in enumerate(text):
        if removed[orig_pos]:
            orig_to_prep[orig_pos] = -1
        else:
            prep_to_orig[prep_pos] = orig_pos
            orig_to_prep[orig_pos] = prep_pos
            kept_chars.append(char)
            prep_pos += 1
    return ''.join(kept_chars), prep_to_orig, orig_to_prep
def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
    """
    Fetch all daf texts from the Vercel API endpoint.

    Args:
        vercel_base_url: Base URL of the Vercel app.
        auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN).

    Returns:
        List of daf objects with id and text_content.
    """
    endpoint = f"{vercel_base_url}/api/dafim-texts"
    print(f"Fetching daf texts from {endpoint}...")
    # Authentication token goes in a custom header expected by the API
    request_headers = {
        'x-auth-token': auth_token,
        'Content-Type': 'application/json',
    }
    try:
        resp = requests.get(endpoint, headers=request_headers, timeout=60)
        resp.raise_for_status()
        payload = resp.json()
        print(f"Fetched {payload.get('count', 0)} dafim")
        return payload.get('dafim', [])
    except Exception as exc:
        print(f"Error fetching daf texts: {exc}")
        # requests exceptions may carry the failed response; surface it for debugging
        if hasattr(exc, 'response') and exc.response is not None:
            print(f"Response status: {exc.response.status_code}")
            print(f"Response text: {exc.response.text}")
        raise
def text_to_sequence(text: str, word_to_idx: dict) -> list:
    """Convert whitespace-separated text into a list of word indices.

    Words absent from the vocabulary map to the '<UNK>' index.

    Raises:
        ValueError: If the vocabulary lacks the required '<UNK>' or '<PAD>' keys.
    """
    # Both sentinel entries must be present before any lookup is attempted
    for required in ('<UNK>', '<PAD>'):
        if required not in word_to_idx:
            raise ValueError(f"Vocabulary must contain '{required}' key")
    unk_idx = word_to_idx['<UNK>']
    return [word_to_idx.get(token, unk_idx) for token in text.split()]
def generate_predictions_for_daf(
    model: torch.nn.Module,
    daf_text: str,
    word_to_idx: dict,
    label_encoder,
    max_len: int = MAX_LEN
) -> list:
    """
    Generate predictions for a single daf text (original text, not preprocessed).

    Strategy: sliding-window — predict on overlapping windows of words, map
    each window back to character positions in the ORIGINAL text via the
    preprocessing position maps, then merge overlapping/adjacent ranges of
    the same predicted type.

    Args:
        model: Trained classifier; called as model(batch_tensor) -> logits.
        daf_text: Original (unpreprocessed) daf text.
        word_to_idx: Vocabulary mapping; must contain '<UNK>' and '<PAD>'.
        label_encoder: Kept for interface compatibility (not used directly).
        max_len: Window size in words (defaults to training MAX_LEN).

    Returns:
        List of ranges: [{'start': int, 'end': int, 'type': int}, ...] with
        positions relative to the original text.
    """
    model.eval()
    # Preprocess the text and get character mappings between original and
    # preprocessed positions
    preprocessed_text, prep_to_orig, orig_to_prep = preprocess_text(daf_text)
    words = preprocessed_text.split()
    if len(words) == 0:
        return []
    # Build word boundaries (start, end) in the preprocessed text by scanning
    # forward; more reliable than repeated find() which could match a wrong
    # occurrence of a repeated word
    word_boundaries = []
    char_pos = 0
    word_idx = 0
    while char_pos < len(preprocessed_text) and word_idx < len(words):
        # Skip leading spaces
        while char_pos < len(preprocessed_text) and preprocessed_text[char_pos] == ' ':
            char_pos += 1
        if char_pos >= len(preprocessed_text):
            break
        word = words[word_idx]
        word_start = char_pos
        # Check if the word starts at this position
        if preprocessed_text[char_pos:char_pos + len(word)] == word:
            word_end = char_pos + len(word)
            word_boundaries.append((word_start, word_end))
            char_pos = word_end
            word_idx += 1
        else:
            # Shouldn't happen (words come from preprocessed_text.split()),
            # but recover by searching forward for the word
            found_pos = preprocessed_text.find(word, char_pos)
            if found_pos != -1:
                word_boundaries.append((found_pos, found_pos + len(word)))
                char_pos = found_pos + len(word)
                word_idx += 1
            else:
                # Couldn't find the word at all: warn and estimate its
                # position so downstream ranges are at worst slightly off
                warnings.warn(f"Word '{word}' at index {word_idx} not found in preprocessed text. Using estimated position.")
                estimated_start = char_pos
                estimated_end = estimated_start + len(word)
                word_boundaries.append((estimated_start, min(estimated_end, len(preprocessed_text))))
                char_pos = estimated_end
                word_idx += 1
    # Validate that we found boundaries for all words
    if len(word_boundaries) < len(words):
        warnings.warn(f"Only found boundaries for {len(word_boundaries)} out of {len(words)} words. Some predictions may be inaccurate.")
    # Sliding window with 50% overlap. Guard the stride so max_len == 1
    # cannot produce range(..., step=0), which raises ValueError.
    window_size = max_len
    stride = max(1, window_size // 2)
    ranges = []
    # Hoisted out of the window loop: sorted prep positions for the fallback
    # nearest-position searches below (previously re-sorted per window)
    sorted_prep_positions = sorted(prep_to_orig.keys())
    with torch.no_grad():
        for i in range(0, len(words), stride):
            window_words = words[i:i + window_size]
            if len(window_words) == 0:
                break
            seq = text_to_sequence(' '.join(window_words), word_to_idx)
            # Pad or truncate to max_len
            if len(seq) > max_len:
                seq = seq[:max_len]
            else:
                seq = seq + [word_to_idx['<PAD>']] * (max_len - len(seq))
            # Batch of one window
            seq_tensor = torch.tensor([seq], dtype=torch.long)
            output = model(seq_tensor)
            _, predicted = torch.max(output.data, 1)
            predicted_label_idx = predicted.item()
            # Character span of the window in the preprocessed text;
            # skip windows we have no boundaries for
            if i >= len(word_boundaries):
                continue
            last_word_idx = min(i + len(window_words) - 1, len(word_boundaries) - 1)
            if last_word_idx < i:
                continue
            window_start_prep = word_boundaries[i][0]
            window_end_prep = word_boundaries[last_word_idx][1]
            if window_end_prep <= window_start_prep:
                continue
            # Map preprocessed start -> original start
            original_start = prep_to_orig.get(window_start_prep)
            if original_start is None:
                # Defensive fallback: closest mapped position at or before
                for prep_pos in reversed(sorted_prep_positions):
                    if prep_pos <= window_start_prep:
                        original_start = prep_to_orig[prep_pos]
                        break
                if original_start is None:
                    continue  # Skip if we can't map the start position
            # Map preprocessed end (exclusive) -> original end (exclusive).
            # window_end_prep points just past the last character of the window.
            window_end_prep_clamped = min(window_end_prep, len(preprocessed_text))
            if window_end_prep_clamped >= len(preprocessed_text):
                # Window reaches the end of the preprocessed text: use the
                # end of the original text
                original_end = len(daf_text)
            else:
                end_char_orig = prep_to_orig.get(window_end_prep_clamped)
                if end_char_orig is not None:
                    original_end = end_char_orig
                else:
                    # Position after the window has no mapping: use the next
                    # mapped preprocessed position, or the end of the text
                    next_prep_pos = None
                    for prep_pos in sorted_prep_positions:
                        if prep_pos >= window_end_prep_clamped:
                            next_prep_pos = prep_pos
                            break
                    if next_prep_pos is not None:
                        original_end = prep_to_orig[next_prep_pos]
                    else:
                        original_end = len(daf_text)
            # Ensure a non-empty, in-bounds range
            if original_end <= original_start:
                original_end = min(original_start + 1, len(daf_text))
            original_end = min(original_end, len(daf_text))
            ranges.append({
                'start': original_start,
                'end': original_end,
                'type': int(predicted_label_idx)
            })
    if len(ranges) == 0:
        return []
    # Merge overlapping/adjacent ranges of the same predicted type
    ranges.sort(key=lambda x: x['start'])
    merged_ranges = []
    current_range = ranges[0].copy()
    for next_range in ranges[1:]:
        if (next_range['type'] == current_range['type'] and
                next_range['start'] <= current_range['end']):
            current_range['end'] = max(current_range['end'], next_range['end'])
        else:
            merged_ranges.append(current_range)
            current_range = next_range.copy()
    merged_ranges.append(current_range)
    return merged_ranges
def generate_all_predictions(
    model: torch.nn.Module,
    word_to_idx: dict,
    label_encoder,
    vercel_base_url: str,
    auth_token: str
) -> list:
    """
    DEPRECATED: no longer used in the training flow; kept for reference only.

    Generate predictions for all dafim fetched from the Vercel API.

    NOTE: The API returns preprocessed text, but generate_predictions_for_daf
    now expects ORIGINAL text, so character positions produced here would be
    incorrect. This function needs updating before it is ever called again.

    Args:
        model: Trained model.
        word_to_idx: Word to index mapping.
        label_encoder: Label encoder.
        vercel_base_url: Base URL of the Vercel app.
        auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN).

    Returns:
        List of prediction objects: [{'daf_id': str, 'ranges': [...]}, ...]
    """
    print("WARNING: generate_all_predictions is deprecated and may not work correctly.")
    print("Fetching daf texts from Vercel...")
    dafim = fetch_daf_texts(vercel_base_url, auth_token)
    if len(dafim) == 0:
        print("No dafim found")
        return []
    total = len(dafim)
    results = []
    print(f"Generating predictions for {total} dafim...")
    for position, daf in enumerate(dafim, start=1):
        # Progress marker every 100 dafim
        if position % 100 == 0:
            print(f"Processed {position}/{total} dafim...")
        try:
            daf_id = daf['id']
            # text_content here is preprocessed text (see deprecation NOTE
            # in the docstring) — positions will not line up with originals
            daf_ranges = generate_predictions_for_daf(
                model, daf['text_content'], word_to_idx, label_encoder
            )
            results.append({
                'daf_id': daf_id,
                'ranges': daf_ranges
            })
        except Exception as e:
            # Best-effort: log the failure and move on to the next daf
            print(f"Error generating predictions for daf {daf.get('id', 'unknown')}: {e}")
            continue
    print(f"Generated predictions for {len(results)} dafim")
    return results