eli5_clm-model / test_inference.py
carlosdelfino's picture
End of training
862bc7d verified
#!/usr/bin/env python3
import argparse
import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def parse_args():
parser = argparse.ArgumentParser(description="Teste de inferência para eli5_clm-model (CLM)")
parser.add_argument("--model_dir", type=str, default=".", help="Diretório do modelo (pasta que contém config.json, tokenizer, pesos, etc.)")
parser.add_argument("--prompt", type=str, required=True, help="Texto de entrada para geração")
parser.add_argument("--max_new_tokens", type=int, default=80, help="Máximo de novos tokens a gerar")
parser.add_argument("--temperature", type=float, default=0.7, help="Temperatura para amostragem (criatividade)")
parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus sampling)")
parser.add_argument("--do_sample", type=lambda x: str(x).lower() in {"1","true","yes","y"}, default=True,
help="Se verdadeiro, usa amostragem; se falso, greedy (padrao: true)")
parser.add_argument("--seed", type=int, default=None, help="Semente para reprodutibilidade")
parser.add_argument("--device", type=str, choices=["auto", "cpu", "cuda"], default="auto",
help="Força dispositivo: auto/cpu/cuda")
return parser.parse_args()
def select_device(choice: str) -> torch.device:
if choice == "cpu":
return torch.device("cpu")
if choice == "cuda":
if torch.cuda.is_available():
return torch.device("cuda")
print("[aviso] CUDA não disponível, usando CPU.")
return torch.device("cpu")
# auto
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def main():
args = parse_args()
if args.seed is not None:
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
device = select_device(args.device)
print(f"[info] Usando dispositivo: {device}")
model_dir = os.path.abspath(args.model_dir)
if not os.path.isdir(model_dir):
print(f"[erro] Diretório do modelo não encontrado: {model_dir}")
sys.exit(1)
print("[info] Carregando tokenizer e modelo...")
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.to(device)
model.eval()
inputs = tokenizer(args.prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
gen_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
}
if args.do_sample:
gen_kwargs.update({
"temperature": args.temperature,
"top_p": args.top_p,
})
print("[info] Gerando texto...")
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\n=== Saída completa ===\n")
print(full_text)
# Tentar extrair apenas a continuação gerada (se compatível com o tokenizer)
try:
prompt_len = len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True))
print("\n=== Continuação gerada ===\n")
print(full_text[prompt_len:])
except Exception:
pass
if __name__ == "__main__":
main()