Spaces:
Sleeping
Sleeping
Update predict.py
Browse files - predict.py +21 -2
predict.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch
|
|
| 7 |
import requests
|
| 8 |
import os
|
| 9 |
import re
|
|
|
|
| 10 |
from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN
|
| 11 |
|
| 12 |
# Preprocessing regex to match Vercel's preprocessing exactly
|
|
@@ -94,6 +95,12 @@ def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
|
|
| 94 |
|
| 95 |
def text_to_sequence(text: str, word_to_idx: dict) -> list:
|
| 96 |
"""Convert text to sequence of word indices"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
words = text.split()
|
| 98 |
return [word_to_idx.get(word, word_to_idx['<UNK>']) for word in words]
|
| 99 |
|
|
@@ -156,8 +163,20 @@ def generate_predictions_for_daf(
|
|
| 156 |
char_pos = found_pos + len(word)
|
| 157 |
word_idx += 1
|
| 158 |
else:
|
| 159 |
-
#
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# Use sliding window approach
|
| 163 |
window_size = max_len
|
|
|
|
| 7 |
import requests
|
| 8 |
import os
|
| 9 |
import re
|
| 10 |
+
import warnings
|
| 11 |
from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN
|
| 12 |
|
| 13 |
# Preprocessing regex to match Vercel's preprocessing exactly
|
|
|
|
| 95 |
|
| 96 |
def text_to_sequence(text: str, word_to_idx: dict) -> list:
|
| 97 |
"""Convert text to sequence of word indices"""
|
| 98 |
+
# Validate that required keys exist
|
| 99 |
+
if '<UNK>' not in word_to_idx:
|
| 100 |
+
raise ValueError("Vocabulary must contain '<UNK>' key")
|
| 101 |
+
if '<PAD>' not in word_to_idx:
|
| 102 |
+
raise ValueError("Vocabulary must contain '<PAD>' key")
|
| 103 |
+
|
| 104 |
words = text.split()
|
| 105 |
return [word_to_idx.get(word, word_to_idx['<UNK>']) for word in words]
|
| 106 |
|
|
|
|
| 163 |
char_pos = found_pos + len(word)
|
| 164 |
word_idx += 1
|
| 165 |
else:
|
| 166 |
+
# Couldn't find word - this indicates a mismatch between words and preprocessed_text
|
| 167 |
+
# This can happen if preprocessing changed the text in an unexpected way
|
| 168 |
+
# Log a warning and use a fallback: estimate position based on character count
|
| 169 |
+
warnings.warn(f"Word '{word}' at index {word_idx} not found in preprocessed text. Using estimated position.")
|
| 170 |
+
# Estimate position: assume words are separated by single spaces
|
| 171 |
+
estimated_start = char_pos
|
| 172 |
+
estimated_end = estimated_start + len(word)
|
| 173 |
+
word_boundaries.append((estimated_start, min(estimated_end, len(preprocessed_text))))
|
| 174 |
+
char_pos = estimated_end
|
| 175 |
+
word_idx += 1
|
| 176 |
+
|
| 177 |
+
# Validate that we found boundaries for all words
|
| 178 |
+
if len(word_boundaries) < len(words):
|
| 179 |
+
warnings.warn(f"Only found boundaries for {len(word_boundaries)} out of {len(words)} words. Some predictions may be inaccurate.")
|
| 180 |
|
| 181 |
# Use sliding window approach
|
| 182 |
window_size = max_len
|