""" Generador de texto usando modelos GPT locales """ import torch from typing import List, Dict import logging logger = logging.getLogger(__name__) class TextGenerator: def __init__(self, model_loader): self.model_loader = model_loader self.chat_history_ids = None def generate_response(self, user_input: str, **kwargs) -> str: """ Genera una respuesta basada en la entrada del usuario Args: user_input (str): Mensaje del usuario **kwargs: Parámetros de generación (max_length, temperature, etc.) Returns: str: Respuesta generada """ if not self.model_loader.is_loaded(): return "Error: No hay modelo cargado" try: # Parámetros por defecto max_length = kwargs.get('max_length', 512) temperature = kwargs.get('temperature', 0.7) top_p = kwargs.get('top_p', 0.9) do_sample = kwargs.get('do_sample', True) # Codificar la entrada del usuario new_user_input_ids = self.model_loader.tokenizer.encode( user_input + self.model_loader.tokenizer.eos_token, return_tensors='pt' ).to(self.model_loader.device) # Concatenar con el historial de chat if self.chat_history_ids is not None: bot_input_ids = torch.cat([self.chat_history_ids, new_user_input_ids], dim=-1) else: bot_input_ids = new_user_input_ids # Generar respuesta with torch.no_grad(): chat_history_ids = self.model_loader.model.generate( bot_input_ids, max_length=max_length, num_beams=1, do_sample=do_sample, temperature=temperature, top_p=top_p, pad_token_id=self.model_loader.tokenizer.eos_token_id, attention_mask=torch.ones(bot_input_ids.shape, device=self.model_loader.device) ) # Actualizar historial self.chat_history_ids = chat_history_ids # Decodificar solo la nueva respuesta response = self.model_loader.tokenizer.decode( chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True ) return str(response).strip() except Exception as e: logger.error(f"Error en la generación: {str(e)}") return f"Error al generar respuesta: {str(e)}" def generate_text(self, prompt: str, **kwargs) -> str: """ Genera texto continuando un prompt (sin historial de chat) Args: prompt (str): Texto inicial **kwargs: Parámetros de generación Returns: str: Texto generado """ if not self.model_loader.is_loaded(): return "Error: No hay modelo cargado" try: # Parámetros por defecto max_length = kwargs.get('max_length', 100) temperature = kwargs.get('temperature', 0.8) top_p = kwargs.get('top_p', 0.9) do_sample = kwargs.get('do_sample', True) # Codificar el prompt input_ids = self.model_loader.tokenizer.encode( prompt, return_tensors='pt' ).to(self.model_loader.device) # Generar texto with torch.no_grad(): output = self.model_loader.model.generate( input_ids, max_length=input_ids.shape[1] + max_length, do_sample=do_sample, temperature=temperature, top_p=top_p, pad_token_id=self.model_loader.tokenizer.eos_token_id, attention_mask=torch.ones(input_ids.shape, device=self.model_loader.device) ) # Decodificar solo el texto generado generated_text = self.model_loader.tokenizer.decode( output[0][input_ids.shape[1]:], skip_special_tokens=True ) return str(generated_text.strip()) except Exception as e: logger.error(f"Error en la generación: {str(e)}") return f"Error al generar texto: {str(e)}" def reset_chat_history(self): """Reinicia el historial de chat""" self.chat_history_ids = None logger.info("Historial de chat reiniciado") def get_generation_stats(self) -> Dict: """Retorna estadísticas de generación""" if self.chat_history_ids is not None: return { "history_length": self.chat_history_ids.shape[1], "device": str(self.chat_history_ids.device) } return {"history_length": 0, "device": "N/A"}