|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
parser = argparse.ArgumentParser(description="Teste de inferência para eli5_clm-model (CLM)") |
|
|
parser.add_argument("--model_dir", type=str, default=".", help="Diretório do modelo (pasta que contém config.json, tokenizer, pesos, etc.)") |
|
|
parser.add_argument("--prompt", type=str, required=True, help="Texto de entrada para geração") |
|
|
parser.add_argument("--max_new_tokens", type=int, default=80, help="Máximo de novos tokens a gerar") |
|
|
parser.add_argument("--temperature", type=float, default=0.7, help="Temperatura para amostragem (criatividade)") |
|
|
parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus sampling)") |
|
|
parser.add_argument("--do_sample", type=lambda x: str(x).lower() in {"1","true","yes","y"}, default=True, |
|
|
help="Se verdadeiro, usa amostragem; se falso, greedy (padrao: true)") |
|
|
parser.add_argument("--seed", type=int, default=None, help="Semente para reprodutibilidade") |
|
|
parser.add_argument("--device", type=str, choices=["auto", "cpu", "cuda"], default="auto", |
|
|
help="Força dispositivo: auto/cpu/cuda") |
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
def select_device(choice: str) -> torch.device: |
|
|
if choice == "cpu": |
|
|
return torch.device("cpu") |
|
|
if choice == "cuda": |
|
|
if torch.cuda.is_available(): |
|
|
return torch.device("cuda") |
|
|
print("[aviso] CUDA não disponível, usando CPU.") |
|
|
return torch.device("cpu") |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
return torch.device("cuda") |
|
|
return torch.device("cpu") |
|
|
|
|
|
|
|
|
def main(): |
|
|
args = parse_args() |
|
|
|
|
|
if args.seed is not None: |
|
|
torch.manual_seed(args.seed) |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.manual_seed_all(args.seed) |
|
|
|
|
|
device = select_device(args.device) |
|
|
print(f"[info] Usando dispositivo: {device}") |
|
|
|
|
|
model_dir = os.path.abspath(args.model_dir) |
|
|
if not os.path.isdir(model_dir): |
|
|
print(f"[erro] Diretório do modelo não encontrado: {model_dir}") |
|
|
sys.exit(1) |
|
|
|
|
|
print("[info] Carregando tokenizer e modelo...") |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_dir) |
|
|
model = AutoModelForCausalLM.from_pretrained(model_dir) |
|
|
model.to(device) |
|
|
model.eval() |
|
|
|
|
|
inputs = tokenizer(args.prompt, return_tensors="pt") |
|
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
|
|
|
|
gen_kwargs = { |
|
|
"max_new_tokens": args.max_new_tokens, |
|
|
"do_sample": args.do_sample, |
|
|
} |
|
|
if args.do_sample: |
|
|
gen_kwargs.update({ |
|
|
"temperature": args.temperature, |
|
|
"top_p": args.top_p, |
|
|
}) |
|
|
|
|
|
print("[info] Gerando texto...") |
|
|
with torch.no_grad(): |
|
|
outputs = model.generate(**inputs, **gen_kwargs) |
|
|
|
|
|
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
print("\n=== Saída completa ===\n") |
|
|
print(full_text) |
|
|
|
|
|
|
|
|
try: |
|
|
prompt_len = len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)) |
|
|
print("\n=== Continuação gerada ===\n") |
|
|
print(full_text[prompt_len:]) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|