File size: 5,149 Bytes
22ca508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
"""
Generador de texto usando modelos GPT locales
"""
import torch
from typing import List, Dict
import logging
logger = logging.getLogger(__name__)
class TextGenerator:
def __init__(self, model_loader):
self.model_loader = model_loader
self.chat_history_ids = None
def generate_response(self, user_input: str, **kwargs) -> str:
"""
Genera una respuesta basada en la entrada del usuario
Args:
user_input (str): Mensaje del usuario
**kwargs: Par谩metros de generaci贸n (max_length, temperature, etc.)
Returns:
str: Respuesta generada
"""
if not self.model_loader.is_loaded():
return "Error: No hay modelo cargado"
try:
# Par谩metros por defecto
max_length = kwargs.get('max_length', 512)
temperature = kwargs.get('temperature', 0.7)
top_p = kwargs.get('top_p', 0.9)
do_sample = kwargs.get('do_sample', True)
# Codificar la entrada del usuario
new_user_input_ids = self.model_loader.tokenizer.encode(
user_input + self.model_loader.tokenizer.eos_token,
return_tensors='pt'
).to(self.model_loader.device)
# Concatenar con el historial de chat
if self.chat_history_ids is not None:
bot_input_ids = torch.cat([self.chat_history_ids, new_user_input_ids], dim=-1)
else:
bot_input_ids = new_user_input_ids
# Generar respuesta
with torch.no_grad():
chat_history_ids = self.model_loader.model.generate(
bot_input_ids,
max_length=max_length,
num_beams=1,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
pad_token_id=self.model_loader.tokenizer.eos_token_id,
attention_mask=torch.ones(bot_input_ids.shape, device=self.model_loader.device)
)
# Actualizar historial
self.chat_history_ids = chat_history_ids
# Decodificar solo la nueva respuesta
response = self.model_loader.tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True
)
return str(response).strip()
except Exception as e:
logger.error(f"Error en la generaci贸n: {str(e)}")
return f"Error al generar respuesta: {str(e)}"
def generate_text(self, prompt: str, **kwargs) -> str:
"""
Genera texto continuando un prompt (sin historial de chat)
Args:
prompt (str): Texto inicial
**kwargs: Par谩metros de generaci贸n
Returns:
str: Texto generado
"""
if not self.model_loader.is_loaded():
return "Error: No hay modelo cargado"
try:
# Par谩metros por defecto
max_length = kwargs.get('max_length', 100)
temperature = kwargs.get('temperature', 0.8)
top_p = kwargs.get('top_p', 0.9)
do_sample = kwargs.get('do_sample', True)
# Codificar el prompt
input_ids = self.model_loader.tokenizer.encode(
prompt,
return_tensors='pt'
).to(self.model_loader.device)
# Generar texto
with torch.no_grad():
output = self.model_loader.model.generate(
input_ids,
max_length=input_ids.shape[1] + max_length,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
pad_token_id=self.model_loader.tokenizer.eos_token_id,
attention_mask=torch.ones(input_ids.shape, device=self.model_loader.device)
)
# Decodificar solo el texto generado
generated_text = self.model_loader.tokenizer.decode(
output[0][input_ids.shape[1]:],
skip_special_tokens=True
)
return str(generated_text.strip())
except Exception as e:
logger.error(f"Error en la generaci贸n: {str(e)}")
return f"Error al generar texto: {str(e)}"
def reset_chat_history(self):
"""Reinicia el historial de chat"""
self.chat_history_ids = None
logger.info("Historial de chat reiniciado")
def get_generation_stats(self) -> Dict:
"""Retorna estad铆sticas de generaci贸n"""
if self.chat_history_ids is not None:
return {
"history_length": self.chat_history_ids.shape[1],
"device": str(self.chat_history_ids.device)
}
return {"history_length": 0, "device": "N/A"}
|