Spaces:
Sleeping
Sleeping
Commit
·
a67ba36
1
Parent(s):
4992a8e
split those sentences
Browse files
app.py
CHANGED
|
@@ -6,10 +6,11 @@ import logging
|
|
| 6 |
import torch
|
| 7 |
import nltk
|
| 8 |
import os
|
| 9 |
-
|
| 10 |
|
| 11 |
from nltk.tokenize import sent_tokenize
|
| 12 |
|
|
|
|
| 13 |
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
|
| 14 |
nltk.data.path.append(nltk_data_path)
|
| 15 |
|
|
@@ -28,7 +29,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
| 28 |
|
| 29 |
# Token constraints
|
| 30 |
MAX_MODEL_TOKENS = 1024
|
| 31 |
-
SAFE_CHUNK_SIZE =
|
| 32 |
|
| 33 |
# Pydantic schemas
|
| 34 |
class SummarizationItem(BaseModel):
|
|
@@ -45,10 +46,40 @@ class SummarizationResponseItem(BaseModel):
|
|
| 45 |
class BatchSummarizationResponse(BaseModel):
|
| 46 |
summaries: List[SummarizationResponseItem]
|
| 47 |
|
| 48 |
-
# Sentence
|
| 49 |
-
def split_sentences(text: str) -> list[str]:
|
| 50 |
-
|
|
|
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
|
| 53 |
sentences = split_sentences(text)
|
| 54 |
chunks = []
|
|
@@ -56,7 +87,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
|
|
| 56 |
|
| 57 |
for sentence in sentences:
|
| 58 |
tentative_chunk = " ".join(current_chunk_sentences + [sentence])
|
| 59 |
-
token_count = len(tokenizer.
|
| 60 |
|
| 61 |
if token_count <= max_tokens:
|
| 62 |
current_chunk_sentences.append(sentence)
|
|
@@ -68,12 +99,11 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
|
|
| 68 |
if current_chunk_sentences:
|
| 69 |
chunks.append(" ".join(current_chunk_sentences))
|
| 70 |
|
| 71 |
-
# Final
|
| 72 |
final_chunks = []
|
| 73 |
for chunk in chunks:
|
| 74 |
encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
|
| 75 |
token_len = encoded["input_ids"].shape[1]
|
| 76 |
-
|
| 77 |
if token_len <= MAX_MODEL_TOKENS:
|
| 78 |
final_chunks.append(chunk)
|
| 79 |
else:
|
|
@@ -98,7 +128,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
|
|
| 98 |
logger.error("No valid chunks after filtering. Returning empty response.")
|
| 99 |
return {"summaries": []}
|
| 100 |
|
| 101 |
-
#
|
| 102 |
summaries = summarizer(
|
| 103 |
all_chunks,
|
| 104 |
max_length=150,
|
|
@@ -108,7 +138,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
|
|
| 108 |
batch_size=4
|
| 109 |
)
|
| 110 |
|
| 111 |
-
#
|
| 112 |
summary_map = {}
|
| 113 |
for content_id, result in zip(chunk_map, summaries):
|
| 114 |
summary_map.setdefault(content_id, []).append(result["summary_text"])
|
|
|
|
| 6 |
import torch
|
| 7 |
import nltk
|
| 8 |
import os
|
| 9 |
+
import re
|
| 10 |
|
| 11 |
from nltk.tokenize import sent_tokenize
|
| 12 |
|
| 13 |
+
# Configure NLTK to use preloaded data path
|
| 14 |
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
|
| 15 |
nltk.data.path.append(nltk_data_path)
|
| 16 |
|
|
|
|
| 29 |
|
| 30 |
# Token constraints
|
| 31 |
MAX_MODEL_TOKENS = 1024
|
| 32 |
+
SAFE_CHUNK_SIZE = 600 # Reduced to leave room for special tokens
|
| 33 |
|
| 34 |
# Pydantic schemas
|
| 35 |
class SummarizationItem(BaseModel):
|
|
|
|
| 46 |
class BatchSummarizationResponse(BaseModel):
|
| 47 |
summaries: List[SummarizationResponseItem]
|
| 48 |
|
| 49 |
+
# Sentence splitter with fallback for long sentences
|
| 50 |
+
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
    """Split *text* into pieces that each aim to fit a token budget.

    The text is first split into sentences with NLTK's ``sent_tokenize``.
    Any sentence whose token count (measured with the module-level
    ``tokenizer``) exceeds ``max_sentence_tokens`` is split further:

    1. on commas, semicolons and colons that are followed by whitespace
       (the delimiter itself is dropped by ``re.split``);
    2. if a resulting clause is still too long, by greedily packing
       whitespace-separated words (see ``_split_by_words``).

    Args:
        text: Raw input text.
        max_sentence_tokens: Maximum number of tokens allowed per piece.

    Returns:
        The pieces in their original order. Empty pieces are never
        returned. A single word that by itself exceeds the budget cannot
        be split here and is returned unchanged.
    """
    pieces: list[str] = []

    for sentence in sent_tokenize(text.strip()):
        if len(tokenizer.tokenize(sentence)) <= max_sentence_tokens:
            pieces.append(sentence)
            continue

        # Fallback: split on commas, semicolons and colons.
        for clause in re.split(r'[;,:]\s+', sentence):
            clause = clause.strip()
            if not clause:
                continue
            if len(tokenizer.tokenize(clause)) <= max_sentence_tokens:
                pieces.append(clause)
            else:
                # Final fallback: hard-split by word.
                pieces.extend(_split_by_words(clause, max_sentence_tokens))

    return pieces


def _split_by_words(text: str, max_tokens: int) -> list[str]:
    """Greedily pack the words of *text* into pieces of at most *max_tokens* tokens."""
    pieces: list[str] = []
    buffer: list[str] = []

    for word in text.split():
        buffer.append(word)
        # The joined buffer is re-tokenized on purpose: subword token counts
        # are not guaranteed to be additive across words.
        if len(tokenizer.tokenize(" ".join(buffer))) > max_tokens:
            # Flush everything before the word that caused the overflow.
            # When the buffer holds only that word there is nothing to
            # flush; skipping avoids emitting an empty string.
            if len(buffer) > 1:
                pieces.append(" ".join(buffer[:-1]))
            buffer = [word]

    if buffer:
        pieces.append(" ".join(buffer))

    return pieces
|
| 81 |
+
|
| 82 |
+
# Chunking based on token length
|
| 83 |
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
|
| 84 |
sentences = split_sentences(text)
|
| 85 |
chunks = []
|
|
|
|
| 87 |
|
| 88 |
for sentence in sentences:
|
| 89 |
tentative_chunk = " ".join(current_chunk_sentences + [sentence])
|
| 90 |
+
token_count = len(tokenizer.tokenize(tentative_chunk))
|
| 91 |
|
| 92 |
if token_count <= max_tokens:
|
| 93 |
current_chunk_sentences.append(sentence)
|
|
|
|
| 99 |
if current_chunk_sentences:
|
| 100 |
chunks.append(" ".join(current_chunk_sentences))
|
| 101 |
|
| 102 |
+
# Final model-safe filtering
|
| 103 |
final_chunks = []
|
| 104 |
for chunk in chunks:
|
| 105 |
encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
|
| 106 |
token_len = encoded["input_ids"].shape[1]
|
|
|
|
| 107 |
if token_len <= MAX_MODEL_TOKENS:
|
| 108 |
final_chunks.append(chunk)
|
| 109 |
else:
|
|
|
|
| 128 |
logger.error("No valid chunks after filtering. Returning empty response.")
|
| 129 |
return {"summaries": []}
|
| 130 |
|
| 131 |
+
# Inference
|
| 132 |
summaries = summarizer(
|
| 133 |
all_chunks,
|
| 134 |
max_length=150,
|
|
|
|
| 138 |
batch_size=4
|
| 139 |
)
|
| 140 |
|
| 141 |
+
# Merge summaries by content_id
|
| 142 |
summary_map = {}
|
| 143 |
for content_id, result in zip(chunk_map, summaries):
|
| 144 |
summary_map.setdefault(content_id, []).append(result["summary_text"])
|