Spaces:
Sleeping
Sleeping
| import logging | |
| import torch | |
| from transformers import BertTokenizerFast, EncoderDecoderModel | |
# Module-level logger, named after this module; SummarizationService writes
# its progress messages through it.
logger = logging.getLogger(name=__name__)
class SummarizationService:
    """Spanish abstractive summarizer built on a shared BERT2BERT checkpoint.

    The encoder input is capped at ``_MAX_INPUT_TOKENS`` (longer inputs are
    truncated by the tokenizer), so long texts are split into small
    sentence-aligned chunks. Each chunk is summarized independently and the
    partial summaries are concatenated in their original order.
    """

    # Hub checkpoint. Chosen over mT5 because it produces better Spanish syntax.
    _CHECKPOINT = "mrm8488/bert2bert_shared-spanish-finetuned-summarization"
    # Target chunk size, in tokens. Small chunks force the model to process
    # every part of the text instead of skipping passages.
    _CHUNK_TOKENS = 200
    # Maximum encoder input length; the tokenizer truncates anything longer.
    _MAX_INPUT_TOKENS = 512
    # Generation settings tuned for micro-chunks: with short inputs we want a
    # substantive summary per chunk, not a one-liner.
    _GEN_PARAMS = {
        "min_length": 25,
        "max_length": 100,
        "num_beams": 4,
        "length_penalty": 2.0,  # push the model to write more
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
    }
    # Sentence boundary: a whitespace run that follows terminal punctuation.
    # The punctuation stays attached to its sentence, so nothing is dropped
    # or duplicated when sentences are re-joined.
    _SENTENCE_BOUNDARY = r"(?<=[.!?…])\s+"

    def __init__(self):
        """Load the tokenizer and the encoder-decoder model on CPU in eval mode."""
        ckpt = self._CHECKPOINT
        self.device = torch.device("cpu")
        logger.info("Cargando modelo original BERT2BERT: %s...", ckpt)
        self.tokenizer = BertTokenizerFast.from_pretrained(ckpt)
        # NOTE(review): use_safetensors=False makes transformers load the
        # pickle-based PyTorch weights via torch.load. Unpickling can execute
        # code, so only point this at a trusted checkpoint.
        self.model = EncoderDecoderModel.from_pretrained(
            ckpt,
            low_cpu_mem_usage=False,
            use_safetensors=False,
            torch_dtype=torch.float32,
        )
        # Keep the model on the same device the input tensors are moved to.
        self.model.to(self.device)
        self.model.eval()
        logger.info("Modelo BERT2BERT cargado correctamente.")

    def summarize(self, text: str) -> str:
        """Summarize ``text`` chunk by chunk with the BERT2BERT model.

        The text is split into sentence-aligned chunks of about
        ``_CHUNK_TOKENS`` tokens, which keeps every chunk under the encoder
        input limit. Each chunk is summarized separately, and the non-empty
        partial summaries are joined with single spaces.

        Args:
            text: Source text. Any whitespace run (newlines included) is
                treated as a single space.

        Returns:
            The concatenated summary, or ``""`` when ``text`` is blank.
        """
        # Collapse every whitespace run (\n, \r\n, \t, repeated spaces) into
        # one space so sentence boundaries are detected consistently.
        text = " ".join(text.split())
        if not text:
            return ""

        chunks = self._chunk_text(text, max_tokens=self._CHUNK_TOKENS)
        logger.info("Texto dividido en %d micro-fragmentos para mayor detalle.", len(chunks))

        summaries = []
        for i, chunk in enumerate(chunks, start=1):
            piece = self._summarize_chunk(chunk)
            if piece:
                summaries.append(piece)
            logger.info("Fragmento %d resumido.", i)
        return " ".join(summaries)

    def _summarize_chunk(self, chunk: str) -> str:
        """Run one beam-search generation over ``chunk`` and return the decoded, stripped text."""
        inputs = self.tokenizer(
            [chunk],
            # A batch of one sequence needs no padding. Padding every chunk to
            # 512 only adds masked positions the encoder still has to compute.
            padding=True,
            truncation=True,
            max_length=self._MAX_INPUT_TOKENS,
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **self._GEN_PARAMS,
            )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    def _chunk_text(self, text: str, max_tokens: int) -> list[str]:
        """Split ``text`` into sentence-aligned chunks of at most ``max_tokens`` tokens.

        Sentences end at whitespace that follows ``.``, ``!``, ``?`` or ``…``.
        The punctuation stays with its sentence, and sentences are re-joined
        with a single space. Consecutive sentences are packed greedily while
        the summed per-sentence token count stays within ``max_tokens``.

        A sentence that alone exceeds the budget is cut at word boundaries
        into several chunks. Without this, the whole sentence would be passed
        through and truncated at the encoder input limit.

        Args:
            text: Input text.
            max_tokens: Token budget per chunk, as counted by
                ``self.tokenizer.tokenize``.

        Returns:
            The chunks in original order; ``[]`` for blank input.
        """
        import re

        chunks = []
        current = []
        current_tokens = 0
        for sentence in re.split(self._SENTENCE_BOUNDARY, text.strip()):
            sentence = sentence.strip()
            if not sentence:
                continue
            # Real tokenizer count, so the budget matches what the model sees.
            n_tokens = len(self.tokenizer.tokenize(sentence))
            if n_tokens > max_tokens:
                # Oversized sentence: flush the pending chunk, then emit the
                # sentence as word-aligned pieces that each fit the budget.
                if current:
                    chunks.append(" ".join(current))
                    current, current_tokens = [], 0
                chunks.extend(self._split_long_sentence(sentence, max_tokens))
                continue
            if current_tokens + n_tokens > max_tokens:
                # `current` is non-empty here: an empty chunk (0 tokens) always
                # accepts a sentence of <= max_tokens tokens.
                chunks.append(" ".join(current))
                current, current_tokens = [], 0
            current.append(sentence)
            current_tokens += n_tokens
        if current:
            chunks.append(" ".join(current))
        return chunks

    def _split_long_sentence(self, sentence: str, max_tokens: int) -> list[str]:
        """Cut ``sentence`` at word boundaries into pieces of at most ``max_tokens`` tokens.

        A single word longer than the budget becomes a piece of its own.
        """
        pieces = []
        words = []
        n_tokens = 0
        for word in sentence.split():
            word_tokens = len(self.tokenizer.tokenize(word))
            if words and n_tokens + word_tokens > max_tokens:
                pieces.append(" ".join(words))
                words, n_tokens = [], 0
            words.append(word)
            n_tokens += word_tokens
        if words:
            pieces.append(" ".join(words))
        return pieces