File size: 5,149 Bytes
22ca508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Generador de texto usando modelos GPT locales
"""
import torch
from typing import List, Dict
import logging

logger = logging.getLogger(__name__)

class TextGenerator:
    def __init__(self, model_loader):
        self.model_loader = model_loader
        self.chat_history_ids = None
        
    def generate_response(self, user_input: str, **kwargs) -> str:
        """
        Genera una respuesta basada en la entrada del usuario
        
        Args:
            user_input (str): Mensaje del usuario
            **kwargs: Par谩metros de generaci贸n (max_length, temperature, etc.)
        
        Returns:
            str: Respuesta generada
        """
        if not self.model_loader.is_loaded():
            return "Error: No hay modelo cargado"
        
        try:
            # Par谩metros por defecto
            max_length = kwargs.get('max_length', 512)
            temperature = kwargs.get('temperature', 0.7)
            top_p = kwargs.get('top_p', 0.9)
            do_sample = kwargs.get('do_sample', True)
            
            # Codificar la entrada del usuario
            new_user_input_ids = self.model_loader.tokenizer.encode(
                user_input + self.model_loader.tokenizer.eos_token, 
                return_tensors='pt'
            ).to(self.model_loader.device)
            
            # Concatenar con el historial de chat
            if self.chat_history_ids is not None:
                bot_input_ids = torch.cat([self.chat_history_ids, new_user_input_ids], dim=-1)
            else:
                bot_input_ids = new_user_input_ids
            
            # Generar respuesta
            with torch.no_grad():
                chat_history_ids = self.model_loader.model.generate(
                    bot_input_ids,
                    max_length=max_length,
                    num_beams=1,
                    do_sample=do_sample,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=self.model_loader.tokenizer.eos_token_id,
                    attention_mask=torch.ones(bot_input_ids.shape, device=self.model_loader.device)
                )
            
            # Actualizar historial
            self.chat_history_ids = chat_history_ids
            
            # Decodificar solo la nueva respuesta
            response = self.model_loader.tokenizer.decode(
                chat_history_ids[:, bot_input_ids.shape[-1]:][0], 
                skip_special_tokens=True
            )
            
            return str(response).strip()
            
        except Exception as e:
            logger.error(f"Error en la generaci贸n: {str(e)}")
            return f"Error al generar respuesta: {str(e)}"
    
    def generate_text(self, prompt: str, **kwargs) -> str:
        """
        Genera texto continuando un prompt (sin historial de chat)
        
        Args:
            prompt (str): Texto inicial
            **kwargs: Par谩metros de generaci贸n
        
        Returns:
            str: Texto generado
        """
        if not self.model_loader.is_loaded():
            return "Error: No hay modelo cargado"
        
        try:
            # Par谩metros por defecto
            max_length = kwargs.get('max_length', 100)
            temperature = kwargs.get('temperature', 0.8)
            top_p = kwargs.get('top_p', 0.9)
            do_sample = kwargs.get('do_sample', True)
            
            # Codificar el prompt
            input_ids = self.model_loader.tokenizer.encode(
                prompt, 
                return_tensors='pt'
            ).to(self.model_loader.device)
            
            # Generar texto
            with torch.no_grad():
                output = self.model_loader.model.generate(
                    input_ids,
                    max_length=input_ids.shape[1] + max_length,
                    do_sample=do_sample,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=self.model_loader.tokenizer.eos_token_id,
                    attention_mask=torch.ones(input_ids.shape, device=self.model_loader.device)
                )
            
            # Decodificar solo el texto generado
            generated_text = self.model_loader.tokenizer.decode(
                output[0][input_ids.shape[1]:], 
                skip_special_tokens=True
            )
            
            return str(generated_text.strip())
            
        except Exception as e:
            logger.error(f"Error en la generaci贸n: {str(e)}")
            return f"Error al generar texto: {str(e)}"
    
    def reset_chat_history(self):
        """Reinicia el historial de chat"""
        self.chat_history_ids = None
        logger.info("Historial de chat reiniciado")
    
    def get_generation_stats(self) -> Dict:
        """Retorna estad铆sticas de generaci贸n"""
        if self.chat_history_ids is not None:
            return {
                "history_length": self.chat_history_ids.shape[1],
                "device": str(self.chat_history_ids.device)
            }
        return {"history_length": 0, "device": "N/A"}