shelfgot committed on
Commit
8176e08
·
verified ·
1 Parent(s): 172b660

single predictions

Browse files
Files changed (1) hide show
  1. predict.py +171 -53
predict.py CHANGED
@@ -6,15 +6,61 @@ Generates predictions for all dafim using a trained model
6
  import torch
7
  import requests
8
  import os
 
9
  from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN
10
 
11
- # Preprocessing regex to match Vercel's preprocessing
12
- PREPROCESSING_REGEX = r'[\u0591-\u05C7]|[,\-?!:\.״]+|<big><strong>|<\/strong><\/big>'
13
- import re
14
 
15
- def preprocess_text(text: str) -> str:
16
- """Preprocess text by removing nikud and punctuation"""
17
- return re.sub(PREPROCESSING_REGEX, '', text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
20
  """
@@ -29,38 +75,21 @@ def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
29
  print(f"Fetching daf texts from {url}...")
30
 
31
  try:
 
32
  headers = {
33
  'x-auth-token': auth_token,
34
  'Content-Type': 'application/json'
35
  }
36
-
37
-
38
- vercel_bypass_token = os.getenv('VERCEL_BYPASS_TOKEN')
39
- if vercel_bypass_token:
40
-
41
- separator = '&' if '?' in url else '?'
42
- url = f"{url}{separator}x-vercel-set-bypass-cookie=true&x-vercel-protection-bypass={vercel_bypass_token}"
43
- print(f"Using Vercel bypass token for deployment protection")
44
-
45
  response = requests.get(url, headers=headers, timeout=60)
46
  response.raise_for_status()
47
  data = response.json()
48
  print(f"Fetched {data.get('count', 0)} dafim")
49
  return data.get('dafim', [])
50
- except requests.exceptions.HTTPError as e:
51
- print(f"HTTP Error fetching daf texts: {e}")
52
- if hasattr(e, 'response') and e.response is not None:
53
- print(f"Response status: {e.response.status_code}")
54
- # Print first 500 chars of response for debugging
55
- response_text = e.response.text[:500] if e.response.text else "No response text"
56
- print(f"Response text (first 500 chars): {response_text}")
57
- # Check if it's a deployment protection issue
58
- if e.response.status_code == 401 and 'Authentication Required' in response_text:
59
- print("ERROR: Deployment protection is blocking the request.")
60
- print("Make sure VERCEL_BYPASS_TOKEN is set correctly in HF Space environment variables.")
61
- raise
62
  except Exception as e:
63
  print(f"Error fetching daf texts: {e}")
 
 
 
64
  raise
65
 
66
  def text_to_sequence(text: str, word_to_idx: dict) -> list:
@@ -76,22 +105,60 @@ def generate_predictions_for_daf(
76
  max_len: int = MAX_LEN
77
  ) -> list:
78
  """
79
- Generate predictions for a single daf text.
80
  Returns list of ranges: [{'start': int, 'end': int, 'type': int}, ...]
 
81
 
82
  Strategy: Sliding window approach - predict on overlapping windows of text
83
  """
84
  model.eval()
85
 
86
- # Preprocess the text (should already be preprocessed, but be safe)
87
- preprocessed_text = preprocess_text(daf_text)
88
 
89
- # Split into words
90
  words = preprocessed_text.split()
91
 
92
  if len(words) == 0:
93
  return []
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # Use sliding window approach
96
  window_size = max_len
97
  stride = window_size // 2 # 50% overlap
@@ -124,32 +191,73 @@ def generate_predictions_for_daf(
124
  _, predicted = torch.max(output.data, 1)
125
  predicted_label_idx = predicted.item()
126
 
127
- # Calculate character positions in original text
128
- # Find the start position of this window in the original text
129
- window_text = ' '.join(window_words)
130
-
131
- # Find start position by searching in original text
132
- search_start = 0
133
- if i > 0:
134
- # Approximate position based on previous windows
135
- search_start = len(' '.join(words[:i]))
136
 
137
- # Find actual position in preprocessed text
138
- window_start_char = preprocessed_text.find(window_text, search_start)
 
139
 
140
- if window_start_char == -1:
141
- # Fallback: estimate position
142
- window_start_char = len(' '.join(words[:i])) if i > 0 else 0
143
-
144
- # Use the most confident prediction for the window center
145
- # For simplicity, predict the entire window as the predicted class
146
- window_end_char = window_start_char + len(window_text)
147
 
148
  # Only add if we have a valid range
149
- if window_end_char > window_start_char:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  ranges.append({
151
- 'start': window_start_char,
152
- 'end': window_end_char,
153
  'type': int(predicted_label_idx)
154
  })
155
 
@@ -185,9 +293,15 @@ def generate_all_predictions(
185
  auth_token: str
186
  ) -> list:
187
  """
 
 
 
188
  Generate predictions for all dafim.
189
  Returns list of prediction objects: [{'daf_id': str, 'ranges': [...]}, ...]
190
 
 
 
 
191
  Args:
192
  model: Trained model
193
  word_to_idx: Word to index mapping
@@ -195,6 +309,7 @@ def generate_all_predictions(
195
  vercel_base_url: Base URL of the Vercel app
196
  auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN)
197
  """
 
198
  print("Fetching daf texts from Vercel...")
199
  dafim = fetch_daf_texts(vercel_base_url, auth_token)
200
 
@@ -212,7 +327,10 @@ def generate_all_predictions(
212
 
213
  try:
214
  daf_id = daf['id']
215
- text_content = daf['text_content'] # Already preprocessed from API
 
 
 
216
 
217
  ranges = generate_predictions_for_daf(
218
  model, text_content, word_to_idx, label_encoder
 
6
  import torch
7
  import requests
8
  import os
9
+ import re
10
  from train import TalmudClassifierLSTM, TalmudDataset, MAX_LEN
11
 
12
+ # Preprocessing regex to match Vercel's preprocessing exactly
13
+ # Vercel uses: /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g
14
+ PREPROCESSING_REGEX = re.compile(r'[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>')
15
 
16
def preprocess_text(text: str) -> tuple[str, dict, dict]:
    """
    Strip nikud marks, punctuation, and HTML tags from *text*.

    Mirrors Vercel's preprocessing regex /[\u0591-\u05C7]|[,\-?!:\.״]+|<[^>]+>/g
    while additionally recording a bidirectional character-position mapping
    between the original and preprocessed strings.

    Returns:
        (preprocessed_text, prep_to_orig, orig_to_prep) where
        - prep_to_orig maps preprocessed position -> original position
        - orig_to_prep maps original position -> preprocessed position
          (-1 for characters that were removed)
    """
    kept_chars = []
    prep_to_orig = {}  # preprocessed index -> original index
    orig_to_prep = {}  # original index -> preprocessed index, or -1 if removed
    removed_punct = {',', '-', '?', '!', ':', '.', '״'}
    pos = 0

    while pos < len(text):
        # HTML tags are dropped as whole units: everything from '<' up to and
        # including the matching '>'. An unclosed '<' falls through and is kept.
        if text[pos] == '<':
            close = text.find('>', pos)
            if close != -1:
                for tag_pos in range(pos, close + 1):
                    orig_to_prep[tag_pos] = -1
                pos = close + 1
                continue

        ch = text[pos]
        # Drop Hebrew cantillation/nikud (U+0591..U+05C7) and the listed punctuation.
        if '\u0591' <= ch <= '\u05C7' or ch in removed_punct:
            orig_to_prep[pos] = -1
        else:
            prep_to_orig[len(kept_chars)] = pos
            orig_to_prep[pos] = len(kept_chars)
            kept_chars.append(ch)
        pos += 1

    return ''.join(kept_chars), prep_to_orig, orig_to_prep
64
 
65
  def fetch_daf_texts(vercel_base_url: str, auth_token: str) -> list:
66
  """
 
75
  print(f"Fetching daf texts from {url}...")
76
 
77
  try:
78
+ # Include authentication token in header
79
  headers = {
80
  'x-auth-token': auth_token,
81
  'Content-Type': 'application/json'
82
  }
 
 
 
 
 
 
 
 
 
83
  response = requests.get(url, headers=headers, timeout=60)
84
  response.raise_for_status()
85
  data = response.json()
86
  print(f"Fetched {data.get('count', 0)} dafim")
87
  return data.get('dafim', [])
 
 
 
 
 
 
 
 
 
 
 
 
88
  except Exception as e:
89
  print(f"Error fetching daf texts: {e}")
90
+ if hasattr(e, 'response') and e.response is not None:
91
+ print(f"Response status: {e.response.status_code}")
92
+ print(f"Response text: {e.response.text}")
93
  raise
94
 
95
  def text_to_sequence(text: str, word_to_idx: dict) -> list:
 
105
  max_len: int = MAX_LEN
106
  ) -> list:
107
  """
108
+ Generate predictions for a single daf text (original text, not preprocessed).
109
  Returns list of ranges: [{'start': int, 'end': int, 'type': int}, ...]
110
+ Positions are relative to the original text.
111
 
112
  Strategy: Sliding window approach - predict on overlapping windows of text
113
  """
114
  model.eval()
115
 
116
+ # Preprocess the text and get character mappings
117
+ preprocessed_text, prep_to_orig, orig_to_prep = preprocess_text(daf_text)
118
 
119
+ # Split into words and track character positions accurately
120
  words = preprocessed_text.split()
121
 
122
  if len(words) == 0:
123
  return []
124
 
125
+ # Build word boundaries in preprocessed text by tracking positions as we iterate
126
+ # This is more reliable than using find() which could match wrong occurrences
127
+ word_boundaries = []
128
+ char_pos = 0
129
+ word_idx = 0
130
+
131
+ # Iterate through preprocessed text to find word boundaries
132
+ while char_pos < len(preprocessed_text) and word_idx < len(words):
133
+ # Skip leading spaces
134
+ while char_pos < len(preprocessed_text) and preprocessed_text[char_pos] == ' ':
135
+ char_pos += 1
136
+
137
+ if char_pos >= len(preprocessed_text):
138
+ break
139
+
140
+ # Find the current word
141
+ word = words[word_idx]
142
+ word_start = char_pos
143
+
144
+ # Check if the word starts at this position
145
+ if preprocessed_text[char_pos:char_pos + len(word)] == word:
146
+ word_end = char_pos + len(word)
147
+ word_boundaries.append((word_start, word_end))
148
+ char_pos = word_end
149
+ word_idx += 1
150
+ else:
151
+ # Word doesn't match - this shouldn't happen, but handle gracefully
152
+ # Try to find the word starting from current position
153
+ found_pos = preprocessed_text.find(word, char_pos)
154
+ if found_pos != -1:
155
+ word_boundaries.append((found_pos, found_pos + len(word)))
156
+ char_pos = found_pos + len(word)
157
+ word_idx += 1
158
+ else:
159
+ # Skip this word if we can't find it
160
+ break
161
+
162
  # Use sliding window approach
163
  window_size = max_len
164
  stride = window_size // 2 # 50% overlap
 
191
  _, predicted = torch.max(output.data, 1)
192
  predicted_label_idx = predicted.item()
193
 
194
+ # Calculate character positions in preprocessed text using word boundaries
195
+ # Ensure we don't go out of bounds
196
+ if i >= len(word_boundaries):
197
+ continue
 
 
 
 
 
198
 
199
+ last_word_idx = min(i + len(window_words) - 1, len(word_boundaries) - 1)
200
+ if last_word_idx < i:
201
+ continue
202
 
203
+ # Start position is the start of the first word in the window
204
+ window_start_prep = word_boundaries[i][0]
205
+ # End position is the end of the last word in the window
206
+ window_end_prep = word_boundaries[last_word_idx][1]
 
 
 
207
 
208
  # Only add if we have a valid range
209
+ if window_end_prep > window_start_prep:
210
+ # Map preprocessed text positions to original text positions
211
+ # Find the original start position
212
+ original_start = prep_to_orig.get(window_start_prep)
213
+ if original_start is None:
214
+ # Find the closest mapped position before or at window_start_prep
215
+ for prep_pos in sorted(prep_to_orig.keys(), reverse=True):
216
+ if prep_pos <= window_start_prep:
217
+ original_start = prep_to_orig[prep_pos]
218
+ break
219
+ if original_start is None:
220
+ continue # Skip if we can't map start position
221
+
222
+ # Find the original end position
223
+ # window_end_prep points to the character after the last character in the window
224
+ # We need to map this to the original text
225
+ window_end_prep_clamped = min(window_end_prep, len(preprocessed_text))
226
+
227
+ # Find the original position corresponding to the end of the window
228
+ # If window_end_prep_clamped is at the end of preprocessed text, use end of original text
229
+ if window_end_prep_clamped >= len(preprocessed_text):
230
+ original_end = len(daf_text)
231
+ else:
232
+ # Find the original position for the character at window_end_prep_clamped
233
+ # (the character right after the window ends)
234
+ end_char_orig = prep_to_orig.get(window_end_prep_clamped)
235
+ if end_char_orig is not None:
236
+ original_end = end_char_orig
237
+ else:
238
+ # Character at window_end_prep_clamped was removed, find the next non-removed character
239
+ # Look for the next preprocessed position >= window_end_prep_clamped
240
+ next_prep_pos = None
241
+ for prep_pos in sorted(prep_to_orig.keys()):
242
+ if prep_pos >= window_end_prep_clamped:
243
+ next_prep_pos = prep_pos
244
+ break
245
+
246
+ if next_prep_pos is not None:
247
+ original_end = prep_to_orig[next_prep_pos]
248
+ else:
249
+ # No more characters in preprocessed text, use end of original text
250
+ original_end = len(daf_text)
251
+
252
+ # Ensure end is after start and within bounds
253
+ if original_end <= original_start:
254
+ # Fallback: ensure at least one character
255
+ original_end = min(original_start + 1, len(daf_text))
256
+ original_end = min(original_end, len(daf_text))
257
+
258
  ranges.append({
259
+ 'start': original_start,
260
+ 'end': original_end,
261
  'type': int(predicted_label_idx)
262
  })
263
 
 
293
  auth_token: str
294
  ) -> list:
295
  """
296
+ DEPRECATED: This function is no longer used in the training flow.
297
+ It's kept for reference but should not be called.
298
+
299
  Generate predictions for all dafim.
300
  Returns list of prediction objects: [{'daf_id': str, 'ranges': [...]}, ...]
301
 
302
+ NOTE: This function expects preprocessed text from the API, but generate_predictions_for_daf
303
+ now expects original text. This function needs to be updated if it's ever used again.
304
+
305
  Args:
306
  model: Trained model
307
  word_to_idx: Word to index mapping
 
309
  vercel_base_url: Base URL of the Vercel app
310
  auth_token: Authentication token for Vercel API (TRAINING_CALLBACK_TOKEN)
311
  """
312
+ print("WARNING: generate_all_predictions is deprecated and may not work correctly.")
313
  print("Fetching daf texts from Vercel...")
314
  dafim = fetch_daf_texts(vercel_base_url, auth_token)
315
 
 
327
 
328
  try:
329
  daf_id = daf['id']
330
+ # NOTE: The API returns preprocessed text, but generate_predictions_for_daf
331
+ # now expects original text. This will cause incorrect character position mapping.
332
+ # This function should fetch original text or be updated to handle preprocessed text.
333
+ text_content = daf['text_content']
334
 
335
  ranges = generate_predictions_for_daf(
336
  model, text_content, word_to_idx, label_encoder