File size: 3,377 Bytes
862bc7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python3
import argparse
import os
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser(description="Teste de inferência para eli5_clm-model (CLM)")
    parser.add_argument("--model_dir", type=str, default=".", help="Diretório do modelo (pasta que contém config.json, tokenizer, pesos, etc.)")
    parser.add_argument("--prompt", type=str, required=True, help="Texto de entrada para geração")
    parser.add_argument("--max_new_tokens", type=int, default=80, help="Máximo de novos tokens a gerar")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperatura para amostragem (criatividade)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus sampling)")
    parser.add_argument("--do_sample", type=lambda x: str(x).lower() in {"1","true","yes","y"}, default=True,
                        help="Se verdadeiro, usa amostragem; se falso, greedy (padrao: true)")
    parser.add_argument("--seed", type=int, default=None, help="Semente para reprodutibilidade")
    parser.add_argument("--device", type=str, choices=["auto", "cpu", "cuda"], default="auto",
                        help="Força dispositivo: auto/cpu/cuda")
    return parser.parse_args()


def select_device(choice: str) -> torch.device:
    if choice == "cpu":
        return torch.device("cpu")
    if choice == "cuda":
        if torch.cuda.is_available():
            return torch.device("cuda")
        print("[aviso] CUDA não disponível, usando CPU.")
        return torch.device("cpu")
    # auto
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def main():
    args = parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    device = select_device(args.device)
    print(f"[info] Usando dispositivo: {device}")

    model_dir = os.path.abspath(args.model_dir)
    if not os.path.isdir(model_dir):
        print(f"[erro] Diretório do modelo não encontrado: {model_dir}")
        sys.exit(1)

    print("[info] Carregando tokenizer e modelo...")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    model.to(device)
    model.eval()

    inputs = tokenizer(args.prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
    }
    if args.do_sample:
        gen_kwargs.update({
            "temperature": args.temperature,
            "top_p": args.top_p,
        })

    print("[info] Gerando texto...")
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("\n=== Saída completa ===\n")
    print(full_text)

    # Tentar extrair apenas a continuação gerada (se compatível com o tokenizer)
    try:
        prompt_len = len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True))
        print("\n=== Continuação gerada ===\n")
        print(full_text[prompt_len:])
    except Exception:
        pass


if __name__ == "__main__":
    main()