Commit · 372b4a1
Parent(s): cba823e
come on
app.py
CHANGED
@@ -4,7 +4,14 @@ from transformers import pipeline, AutoTokenizer
 from typing import List
 import logging
 import torch
-import
+import nltk
+import os
+
+from nltk.tokenize import sent_tokenize
+
+# Download punkt tokenizer if not already present
+nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
+nltk.download("punkt", download_dir=nltk_data_path)
 
 app = FastAPI()
 
@@ -21,7 +28,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Token constraints
 MAX_MODEL_TOKENS = 1024
-SAFE_CHUNK_SIZE =
+SAFE_CHUNK_SIZE = 650  # Lowered for extra safety
 
 # Pydantic schemas
 class SummarizationItem(BaseModel):
@@ -38,9 +45,9 @@ class SummarizationResponseItem(BaseModel):
 class BatchSummarizationResponse(BaseModel):
     summaries: List[SummarizationResponseItem]
 
-# Sentence-based chunking
+# Sentence-based chunking using nltk
 def split_sentences(text: str) -> list[str]:
-    return
+    return sent_tokenize(text.strip())
 
 def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     sentences = split_sentences(text)
@@ -49,7 +56,7 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
 
     for sentence in sentences:
         tentative_chunk = " ".join(current_chunk_sentences + [sentence])
-        token_count = len(tokenizer.encode(tentative_chunk,
+        token_count = len(tokenizer.encode(tentative_chunk, add_special_tokens=False))
 
         if token_count <= max_tokens:
             current_chunk_sentences.append(sentence)
@@ -64,16 +71,16 @@ def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
     # Final filter: ensure nothing slipped through
     final_chunks = []
     for chunk in chunks:
-        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
+        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
         token_len = encoded["input_ids"].shape[1]
+
         if token_len <= MAX_MODEL_TOKENS:
             final_chunks.append(chunk)
         else:
-            logger.warning(f"[CHUNKING] Dropped oversized chunk
+            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
 
     return final_chunks
 
-# Summarization endpoint
 @app.post("/summarize", response_model=BatchSummarizationResponse)
 async def summarize_batch(request: BatchSummarizationRequest):
     all_chunks = []
@@ -91,6 +98,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         logger.error("No valid chunks after filtering. Returning empty response.")
         return {"summaries": []}
 
+    # Batch inference (safe, since we're now filtering properly)
    summaries = summarizer(
         all_chunks,
         max_length=150,
@@ -100,6 +108,7 @@ async def summarize_batch(request: BatchSummarizationRequest):
         batch_size=4
     )
 
+    # Combine summaries by content_id
     summary_map = {}
     for content_id, result in zip(chunk_map, summaries):
         summary_map.setdefault(content_id, []).append(result["summary_text"])
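
For local verification, here is a minimal standalone sketch of the sentence-based chunking this commit introduces. It mirrors the visible parts of split_sentences and chunk_text; the model checkpoint is only a stand-in (the Space's actual model_name is defined outside this diff), and the overflow handling in the else branch is an assumption, since that part of chunk_text is not shown in the hunks above.

# Standalone sketch of the new chunking logic (assumptions noted below).
# "facebook/bart-large-cnn" is only a stand-in -- the real model_name is
# defined earlier in app.py and does not appear in this diff.
import os
import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer

nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.download("punkt", download_dir=nltk_data_path)
nltk.data.path.append(nltk_data_path)

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
SAFE_CHUNK_SIZE = 650

def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
    chunks, current = [], []
    for sentence in sent_tokenize(text.strip()):
        tentative = " ".join(current + [sentence])
        if len(tokenizer.encode(tentative, add_special_tokens=False)) <= max_tokens:
            current.append(sentence)
        else:
            # Assumed overflow handling: close the current chunk, start a new one.
            if current:
                chunks.append(" ".join(current))
            current = [sentence]
    if current:
        chunks.append(" ".join(current))
    return chunks

sample = "This is a fairly ordinary sentence used for testing. " * 300
for i, chunk in enumerate(chunk_text(sample)):
    n = len(tokenizer.encode(chunk, add_special_tokens=False))
    print(f"chunk {i}: {n} tokens (limit {SAFE_CHUNK_SIZE})")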
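
Once the Space is running, the batch endpoint could be exercised along these lines. This is only a sketch: the request fields ("items", "content_id", "text") are guesses, since SummarizationItem and BatchSummarizationRequest are defined outside this diff, and port 7860 is just the usual Spaces default; only the "summaries" key of the response comes directly from BatchSummarizationResponse.

# Hypothetical client call -- request field names are assumptions, see note above.
import requests

payload = {
    "items": [
        {"content_id": "doc-1", "text": "First long document to summarize ..."},
        {"content_id": "doc-2", "text": "Second long document to summarize ..."},
    ]
}
resp = requests.post("http://localhost:7860/summarize", json=payload, timeout=300)
resp.raise_for_status()
for entry in resp.json()["summaries"]:
    print(entry)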