#!/usr/bin/env python3 import argparse import os import sys import torch from transformers import AutoModelForCausalLM, AutoTokenizer def parse_args(): parser = argparse.ArgumentParser(description="Teste de inferência para eli5_clm-model (CLM)") parser.add_argument("--model_dir", type=str, default=".", help="Diretório do modelo (pasta que contém config.json, tokenizer, pesos, etc.)") parser.add_argument("--prompt", type=str, required=True, help="Texto de entrada para geração") parser.add_argument("--max_new_tokens", type=int, default=80, help="Máximo de novos tokens a gerar") parser.add_argument("--temperature", type=float, default=0.7, help="Temperatura para amostragem (criatividade)") parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus sampling)") parser.add_argument("--do_sample", type=lambda x: str(x).lower() in {"1","true","yes","y"}, default=True, help="Se verdadeiro, usa amostragem; se falso, greedy (padrao: true)") parser.add_argument("--seed", type=int, default=None, help="Semente para reprodutibilidade") parser.add_argument("--device", type=str, choices=["auto", "cpu", "cuda"], default="auto", help="Força dispositivo: auto/cpu/cuda") return parser.parse_args() def select_device(choice: str) -> torch.device: if choice == "cpu": return torch.device("cpu") if choice == "cuda": if torch.cuda.is_available(): return torch.device("cuda") print("[aviso] CUDA não disponível, usando CPU.") return torch.device("cpu") # auto if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def main(): args = parse_args() if args.seed is not None: torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) device = select_device(args.device) print(f"[info] Usando dispositivo: {device}") model_dir = os.path.abspath(args.model_dir) if not os.path.isdir(model_dir): print(f"[erro] Diretório do modelo não encontrado: {model_dir}") sys.exit(1) print("[info] Carregando tokenizer e modelo...") tokenizer = AutoTokenizer.from_pretrained(model_dir) model = AutoModelForCausalLM.from_pretrained(model_dir) model.to(device) model.eval() inputs = tokenizer(args.prompt, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} gen_kwargs = { "max_new_tokens": args.max_new_tokens, "do_sample": args.do_sample, } if args.do_sample: gen_kwargs.update({ "temperature": args.temperature, "top_p": args.top_p, }) print("[info] Gerando texto...") with torch.no_grad(): outputs = model.generate(**inputs, **gen_kwargs) full_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print("\n=== Saída completa ===\n") print(full_text) # Tentar extrair apenas a continuação gerada (se compatível com o tokenizer) try: prompt_len = len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)) print("\n=== Continuação gerada ===\n") print(full_text[prompt_len:]) except Exception: pass if __name__ == "__main__": main()