# app-estudio/services/summarization_service.py
# igna7's picture
# add app files
# 0ce0464 verified
import logging
import torch
from transformers import BertTokenizerFast, EncoderDecoderModel
logger = logging.getLogger(__name__)
class SummarizationService:
    """Summarize Spanish text with a BERT2BERT encoder-decoder model.

    The input is split into small sentence-aligned chunks, each chunk is
    summarized independently, and the partial summaries are concatenated.
    Chunking keeps every model input below the 512-token truncation limit
    applied at tokenization time, so no part of the text is dropped.
    """

    # Original BERT2BERT checkpoint; the author's note says it yields better
    # Spanish syntax than mT5.
    CHECKPOINT = "mrm8488/bert2bert_shared-spanish-finetuned-summarization"

    # Truncation limit passed to the tokenizer for every chunk.
    MAX_INPUT_TOKENS = 512

    # Token budget per chunk. Small chunks make the model look at every part
    # of the text instead of skipping sections.
    CHUNK_TOKENS = 200

    # Generation settings tuned for small chunks: the aim is a substantial
    # summary per chunk rather than a single line.
    GENERATION_PARAMS = {
        "min_length": 25,
        "max_length": 100,
        "num_beams": 4,
        "length_penalty": 2.0,  # favor longer outputs
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
    }

    def __init__(self):
        """Load tokenizer and model on CPU and put the model in eval mode."""
        ckpt = self.CHECKPOINT
        self.device = torch.device("cpu")
        logger.info(f"Cargando modelo original BERT2BERT: {ckpt}...")
        self.tokenizer = BertTokenizerFast.from_pretrained(ckpt)
        self.model = EncoderDecoderModel.from_pretrained(
            ckpt,
            low_cpu_mem_usage=False,
            use_safetensors=False,
            torch_dtype=torch.float32,
        )
        # Keep the model on the same device the inputs are moved to.
        self.model.to(self.device)
        self.model.eval()
        logger.info("Modelo BERT2BERT cargado correctamente.")

    def summarize(self, text: str) -> str:
        """Return a summary of ``text`` built from per-chunk summaries.

        Args:
            text: The text to summarize.

        Returns:
            The chunk summaries joined by single spaces, or an empty string
            when ``text`` contains no sentences.
        """
        # Collapse every whitespace run (newlines, tabs, CRLF) to one space so
        # the ". " sentence boundaries are found reliably.
        text = " ".join(text.split())

        chunks = self._chunk_text(text, max_tokens=self.CHUNK_TOKENS)
        logger.info(f"Texto dividido en {len(chunks)} micro-fragmentos para mayor detalle.")

        summaries = []
        for i, chunk in enumerate(chunks):
            # No padding: each call encodes a single sequence, so there is
            # nothing to align, and padding would only add wasted attention work.
            inputs = self.tokenizer(
                [chunk],
                padding=False,
                truncation=True,
                max_length=self.MAX_INPUT_TOKENS,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **self.GENERATION_PARAMS,
                )

            summary_piece = self.tokenizer.decode(
                output_ids[0], skip_special_tokens=True
            ).strip()
            if summary_piece:
                summaries.append(summary_piece)
            logger.info(f"Fragmento {i+1} resumido.")

        return " ".join(summaries)

    def _chunk_text(self, text: str, max_tokens: int) -> list[str]:
        """Split ``text`` into chunks of roughly ``max_tokens`` tokens or fewer.

        Sentences (delimited by ". ") are packed greedily into chunks. A
        sentence that alone exceeds ``max_tokens`` is split on word
        boundaries so that it is not truncated later by the tokenizer.

        The budget is approximate: the ". " separators re-inserted between
        sentences and the tokenizer's special tokens are not counted.

        Args:
            text: Text to split.
            max_tokens: Token budget per chunk, measured with
                ``self.tokenizer.tokenize``.

        Returns:
            The chunks in document order; empty when ``text`` has no sentences.
        """
        chunks: list[str] = []
        pending: list[str] = []
        pending_tokens = 0

        for sentence in text.split(". "):
            sentence = sentence.strip()
            if not sentence:
                continue

            # Count with the real tokenizer: word counts underestimate
            # the number of sub-word tokens.
            sent_len = len(self.tokenizer.tokenize(sentence))

            if sent_len > max_tokens:
                # Oversized sentence: close the chunk in progress, then emit
                # the sentence as several word-aligned pieces.
                if pending:
                    chunks.append(self._join_sentences(pending))
                    pending, pending_tokens = [], 0
                chunks.extend(
                    self._join_sentences([piece])
                    for piece in self._split_long_sentence(sentence, max_tokens)
                )
                continue

            if pending_tokens + sent_len > max_tokens:
                chunks.append(self._join_sentences(pending))
                pending, pending_tokens = [sentence], sent_len
            else:
                pending.append(sentence)
                pending_tokens += sent_len

        if pending:
            chunks.append(self._join_sentences(pending))
        return chunks

    def _split_long_sentence(self, sentence: str, max_tokens: int) -> list[str]:
        """Split one oversized sentence into word-aligned pieces.

        Each piece holds at most ``max_tokens`` tokens, except when a single
        word alone exceeds the budget; such a word becomes its own piece.
        """
        pieces: list[str] = []
        words: list[str] = []
        used = 0
        for word in sentence.split():
            word_len = len(self.tokenizer.tokenize(word))
            if words and used + word_len > max_tokens:
                pieces.append(" ".join(words))
                words, used = [], 0
            words.append(word)
            used += word_len
        if words:
            pieces.append(" ".join(words))
        return pieces

    @staticmethod
    def _join_sentences(sentences: list[str]) -> str:
        """Join sentences with ". " and make sure the chunk ends a sentence.

        A period is appended only when the chunk does not already end with
        terminal punctuation, so the last sentence of the text (which keeps
        its own "." after splitting) does not end up with "..".
        """
        chunk = ". ".join(sentences)
        if not chunk.endswith((".", "!", "?", "…")):
            chunk += "."
        return chunk