from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
import os
import re
from nltk.tokenize import sent_tokenize

# Configure NLTK to use the preloaded data path
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 600  # Safe for aggregation
TRUNCATED_TOKENS = MAX_MODEL_TOKENS - 2  # Leave room for special tokens

# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
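
# Illustrative request/response shapes for the schemas above (example values only,
# not taken from the original source):
#
#   POST body (BatchSummarizationRequest):
#     {"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}
#
#   Response body (BatchSummarizationResponse):
#     {"summaries": [{"content_id": "doc-1", "summary": "Condensed text ..."}]}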

# Sentence splitter with fallback for long sentences
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text.strip())
    split_results = []
    for sentence in sentences:
        token_len = len(tokenizer.tokenize(sentence))
        if token_len <= max_sentence_tokens:
            split_results.append(sentence)
        else:
            # Fallback: split by commas/semicolons/colons
            sub_sentences = re.split(r'[;,:]\s+', sentence)
            for sub in sub_sentences:
                sub = sub.strip()
                if not sub:
                    continue
                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
                    split_results.append(sub)
                else:
                    # Final fallback: hard-split by word
                    words = sub.split()
                    buffer = []
                    for word in words:
                        buffer.append(word)
                        current = " ".join(buffer)
                        if len(tokenizer.tokenize(current)) > max_sentence_tokens:
                            # Flush everything before the word that overflowed;
                            # guard against emitting an empty segment
                            if len(buffer) > 1:
                                split_results.append(" ".join(buffer[:-1]))
                            buffer = [word]
                    if buffer:
                        split_results.append(" ".join(buffer))
    return split_results

# Truncate text safely at the token level
def truncate_text(text: str, max_tokens: int = TRUNCATED_TOKENS) -> str:
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return text
    truncated = tokens[:max_tokens]
    return tokenizer.decode(truncated, skip_special_tokens=True)

# Chunking based on token length
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []
    for sentence in sentences:
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.tokenize(tentative_chunk))
        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final model-safe filtering
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
    return final_chunks

# NOTE: the route decorator was missing from the extracted source;
# the path and response_model below are assumptions.
@app.post("/summarize_batch", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []
    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
        for chunk in chunks:
            all_chunks.append(truncate_text(chunk))  # Enforce the model's max input length
            chunk_map.append(item.content_id)
    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Inference
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Merge summaries by content_id
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}

# Simple health-check endpoint (decorator missing from the extracted source; "/" is assumed)
@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
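
# --- Minimal local smoke test (a sketch, not part of the original file) ---
# Uses FastAPI's TestClient (requires httpx) so no server has to be running.
# The POST path matches the assumed "/summarize_batch" route above, and the
# sample text is purely illustrative.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    sample_text = (
        "Transformers are neural network architectures built around attention. " * 20
    )
    payload = {"inputs": [{"content_id": "doc-1", "text": sample_text}]}
    resp = client.post("/summarize_batch", json=payload)
    print(resp.status_code)
    print(resp.json())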