apjanco committed on
Commit
a7dd0a2
·
1 Parent(s): c5951c7

assert 6 or more chunks as output

Browse files
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -182,6 +182,49 @@ def _normalize_chunks(chunks: Iterable[object]) -> List[str]:
182
  return normalized
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  def chunk_pages(
186
  extracted_pages: Sequence[ExtractedPage], chunker: SemanticChunker
187
  ) -> List[ChunkedRow]:
@@ -227,7 +270,7 @@ def process_inputs(files: object, directory: object) -> Tuple[str, str]:
227
  for path in file_paths:
228
  extracted.extend(extract_text(path))
229
 
230
- rows = chunk_pages(extracted, chunker)
231
  if not rows:
232
  raise gr.Error("No text could be extracted from the uploaded files.")
233
 
@@ -279,4 +322,11 @@ if __name__ == "__main__":
279
  debug_mode = os.getenv("GRADIO_DEBUG", "0").lower() in {"1", "true", "yes"}
280
  server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
281
  server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
282
- demo.launch(share=True)
 
 
 
 
 
 
 
 
182
  return normalized
183
 
184
 
185
+ def _split_text(text: str) -> Tuple[str, str]:
186
+ words = text.split()
187
+ if len(words) >= 2:
188
+ midpoint = len(words) // 2
189
+ left = " ".join(words[:midpoint]).strip()
190
+ right = " ".join(words[midpoint:]).strip()
191
+ return left, right
192
+ midpoint = max(1, len(text) // 2)
193
+ left = text[:midpoint].strip()
194
+ right = text[midpoint:].strip()
195
+ return left, right
196
+
197
+
198
def ensure_min_chunks(rows: List[ChunkedRow], minimum: int = 6) -> List[ChunkedRow]:
    """Pad *rows* so the result contains at least *minimum* chunks.

    Repeatedly splits the longest chunk in half (word-wise when possible,
    via ``_split_text``) until the minimum count is reached; if splitting
    can no longer make progress, duplicates the final chunk as a last
    resort. Returns a new list; *rows* itself is never mutated.

    Args:
        rows: Chunked rows produced by ``chunk_pages``.
        minimum: Target lower bound on the number of chunks.

    Returns:
        ``rows`` unchanged when it is empty or already long enough,
        otherwise a new list with at least *minimum* entries.
    """
    # Guard: an empty list would crash max(range(0), ...) below. Return it
    # untouched so the caller's own "no rows" error handling can fire.
    if not rows or len(rows) >= minimum:
        return rows

    expanded = rows[:]
    while len(expanded) < minimum:
        # Split the longest chunk — it has the most text to share.
        index = max(range(len(expanded)), key=lambda i: len(expanded[i].chunk_text))
        candidate = expanded.pop(index)
        if len(candidate.chunk_text) <= 1:
            # Nothing left worth splitting; fall back to duplication below.
            expanded.append(candidate)
            break
        left, right = _split_text(candidate.chunk_text)
        if left:
            expanded.append(
                ChunkedRow(candidate.filename, candidate.page_number, left)
            )
        if right:
            expanded.append(
                ChunkedRow(candidate.filename, candidate.page_number, right)
            )
        if not left and not right:
            # Split produced nothing (e.g. whitespace-only text); restore
            # the candidate and stop trying.
            expanded.append(candidate)
            break

    # Last resort: duplicate the final chunk until the minimum is met.
    while len(expanded) < minimum and expanded:
        expanded.append(expanded[-1])

    return expanded
226
+
227
+
228
  def chunk_pages(
229
  extracted_pages: Sequence[ExtractedPage], chunker: SemanticChunker
230
  ) -> List[ChunkedRow]:
 
270
  for path in file_paths:
271
  extracted.extend(extract_text(path))
272
 
273
+ rows = ensure_min_chunks(chunk_pages(extracted, chunker))
274
  if not rows:
275
  raise gr.Error("No text could be extracted from the uploaded files.")
276
 
 
322
  debug_mode = os.getenv("GRADIO_DEBUG", "0").lower() in {"1", "true", "yes"}
323
  server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
324
  server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
325
+ share = os.getenv("GRADIO_SHARE", "false").lower() in {"1", "true", "yes"}
326
+ demo.launch(
327
+ debug=debug_mode,
328
+ server_name=server_name,
329
+ server_port=server_port,
330
+ share=share,
331
+ show_api=False,
332
+ )
tests/__pycache__/test_smoke.cpython-312-pytest-9.0.2.pyc CHANGED
Binary files a/tests/__pycache__/test_smoke.cpython-312-pytest-9.0.2.pyc and b/tests/__pycache__/test_smoke.cpython-312-pytest-9.0.2.pyc differ
 
tests/test_smoke.py CHANGED
@@ -1,5 +1,7 @@
1
  from pathlib import Path
2
 
 
 
3
  from app import process_inputs
4
 
5
 
@@ -11,3 +13,6 @@ def test_process_inputs_creates_csv(tmp_path: Path) -> None:
11
 
12
  assert Path(csv_path).exists()
13
  assert "Processed" in summary
 
 
 
 
1
  from pathlib import Path
2
 
3
+ import csv
4
+
5
  from app import process_inputs
6
 
7
 
 
13
 
14
  assert Path(csv_path).exists()
15
  assert "Processed" in summary
16
+ with open(csv_path, newline="", encoding="utf-8") as handle:
17
+ rows = list(csv.reader(handle))
18
+ assert len(rows) >= 7