Spaces:
Runtime error
Runtime error
File size: 2,656 Bytes
5d78646 7c68366 5d78646 7c68366 5d78646 ceb558e 5d78646 578774c 7c68366 5d78646 578774c 03b3c9a 5d78646 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
# Importamos las clases necesarias de la librer铆a Hugging Face Transformers.
# Usamos GPT2LMHeadModel porque incluye el "Language Model Head"
# necesario para tareas de generaci贸n de texto.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import sys
# --- CONFIGURACI脫N EST脕NDAR ---
# MODEL_ID: Usamos 'gpt2', el identificador oficial y compatible.
# ESTA ES LA CORRECCI脫N CLAVE para evitar el error de compatibilidad.
MODEL_ID = 'gpt2'
PROMPT = "La arquitectura de la red neuronal Transformer se basa en un mecanismo de"
MAX_LENGTH = 70
TEMPERATURE = 0.8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Iniciando carga del modelo GPT-2: {MODEL_ID} en dispositivo: {DEVICE}")
def main():
"""Funci贸n principal para cargar el modelo y generar texto."""
try:
# 1. Cargar el Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
# 2. Cargar el Modelo con el 'Language Modeling Head'
model = GPT2LMHeadModel.from_pretrained(MODEL_ID).to(DEVICE)
# Configuramos el token de padding (necesario para la generaci贸n)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# 3. Codificar el texto de inicio (prompt)
input_ids = tokenizer.encode(PROMPT, return_tensors='pt').to(DEVICE)
print("\n--- Generando Texto ---")
# 4. Generar la secuencia de texto
output = model.generate(
input_ids,
max_length=MAX_LENGTH,
num_return_sequences=1,
do_sample=True, # Permite sampling (generaci贸n m谩s creativa)
top_k=50, # Limita las opciones a las 50 m谩s probables
temperature=TEMPERATURE, # Controla la creatividad (m谩s alto = m谩s aleatorio)
pad_token_id=tokenizer.eos_token_id
)
# 5. Decodificar y mostrar el resultado
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"\nPrompt Original: {PROMPT}")
print("-" * (len(PROMPT) + 15))
print(f"Texto Generado Completo: {generated_text}")
print("\n隆Generaci贸n completada con 茅xito!")
except Exception as e:
print(f"\n[ERROR CR脥TICO] Fallo al cargar el modelo o generar texto.", file=sys.stderr)
print(f"Aseg煤rate de tener instalada la librer铆a 'transformers' y 'torch' (pip install transformers torch).", file=sys.stderr)
print(f"Detalle del error: {e}", file=sys.stderr)
if __name__ == "__main__":
main()
|