File size: 2,656 Bytes
5d78646
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c68366
5d78646
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c68366
5d78646
 
ceb558e
5d78646
 
 
 
578774c
7c68366
5d78646
 
 
578774c
03b3c9a
5d78646
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Importamos las clases necesarias de la librer铆a Hugging Face Transformers.
# Usamos GPT2LMHeadModel porque incluye el "Language Model Head"
# necesario para tareas de generaci贸n de texto.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import sys

# --- CONFIGURACI脫N EST脕NDAR ---
# MODEL_ID: Usamos 'gpt2', el identificador oficial y compatible.
# ESTA ES LA CORRECCI脫N CLAVE para evitar el error de compatibilidad.
MODEL_ID = 'gpt2' 
PROMPT = "La arquitectura de la red neuronal Transformer se basa en un mecanismo de"
MAX_LENGTH = 70
TEMPERATURE = 0.8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Iniciando carga del modelo GPT-2: {MODEL_ID} en dispositivo: {DEVICE}")

def main():
    """Funci贸n principal para cargar el modelo y generar texto."""
    try:
        # 1. Cargar el Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
        
        # 2. Cargar el Modelo con el 'Language Modeling Head'
        model = GPT2LMHeadModel.from_pretrained(MODEL_ID).to(DEVICE)
        
        # Configuramos el token de padding (necesario para la generaci贸n)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

        # 3. Codificar el texto de inicio (prompt)
        input_ids = tokenizer.encode(PROMPT, return_tensors='pt').to(DEVICE)

        print("\n--- Generando Texto ---")
        
        # 4. Generar la secuencia de texto
        output = model.generate(
            input_ids,
            max_length=MAX_LENGTH,
            num_return_sequences=1,
            do_sample=True,             # Permite sampling (generaci贸n m谩s creativa)
            top_k=50,                   # Limita las opciones a las 50 m谩s probables
            temperature=TEMPERATURE,    # Controla la creatividad (m谩s alto = m谩s aleatorio)
            pad_token_id=tokenizer.eos_token_id
        )

        # 5. Decodificar y mostrar el resultado
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        
        print(f"\nPrompt Original: {PROMPT}")
        print("-" * (len(PROMPT) + 15))
        print(f"Texto Generado Completo: {generated_text}")
        print("\n隆Generaci贸n completada con 茅xito!")

    except Exception as e:
        print(f"\n[ERROR CR脥TICO] Fallo al cargar el modelo o generar texto.", file=sys.stderr)
        print(f"Aseg煤rate de tener instalada la librer铆a 'transformers' y 'torch' (pip install transformers torch).", file=sys.stderr)
        print(f"Detalle del error: {e}", file=sys.stderr)

if __name__ == "__main__":
    main()