Andro0s's picture
Update app.py
5d78646 verified
raw
history blame
2.66 kB
# Importamos las clases necesarias de la librería Hugging Face Transformers.
# Usamos GPT2LMHeadModel porque incluye el "Language Model Head"
# necesario para tareas de generación de texto.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import sys
# --- CONFIGURACIÓN ESTÁNDAR ---
# MODEL_ID: Usamos 'gpt2', el identificador oficial y compatible.
# ESTA ES LA CORRECCIÓN CLAVE para evitar el error de compatibilidad.
MODEL_ID = 'gpt2'
PROMPT = "La arquitectura de la red neuronal Transformer se basa en un mecanismo de"
MAX_LENGTH = 70
TEMPERATURE = 0.8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Iniciando carga del modelo GPT-2: {MODEL_ID} en dispositivo: {DEVICE}")
def main():
"""Función principal para cargar el modelo y generar texto."""
try:
# 1. Cargar el Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
# 2. Cargar el Modelo con el 'Language Modeling Head'
model = GPT2LMHeadModel.from_pretrained(MODEL_ID).to(DEVICE)
# Configuramos el token de padding (necesario para la generación)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# 3. Codificar el texto de inicio (prompt)
input_ids = tokenizer.encode(PROMPT, return_tensors='pt').to(DEVICE)
print("\n--- Generando Texto ---")
# 4. Generar la secuencia de texto
output = model.generate(
input_ids,
max_length=MAX_LENGTH,
num_return_sequences=1,
do_sample=True, # Permite sampling (generación más creativa)
top_k=50, # Limita las opciones a las 50 más probables
temperature=TEMPERATURE, # Controla la creatividad (más alto = más aleatorio)
pad_token_id=tokenizer.eos_token_id
)
# 5. Decodificar y mostrar el resultado
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"\nPrompt Original: {PROMPT}")
print("-" * (len(PROMPT) + 15))
print(f"Texto Generado Completo: {generated_text}")
print("\n¡Generación completada con éxito!")
except Exception as e:
print(f"\n[ERROR CRÍTICO] Fallo al cargar el modelo o generar texto.", file=sys.stderr)
print(f"Asegúrate de tener instalada la librería 'transformers' y 'torch' (pip install transformers torch).", file=sys.stderr)
print(f"Detalle del error: {e}", file=sys.stderr)
if __name__ == "__main__":
main()