File size: 2,654 Bytes
ee44f2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Example inference for DANTE-Mosaic-3.5B.

Usage:
    python example_inference.py
    python example_inference.py --model YourOrg/DANTE-Mosaic-3.5B
    python example_inference.py --model ./local_path/

Runs on a single GPU such as an A100, RTX 4090 or H100 (~5.8 GB of VRAM in BF16); falls back to CPU when CUDA is unavailable.
"""
from __future__ import annotations

import argparse
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Demo workload as (tag, prompt) pairs. The tag is only a printed label; each
# prompt is fed verbatim (no chat template) to the model. The last entry asks,
# in Italian, for a simple explanation of machine learning for a high-school
# student.
PROMPTS = [
    (
        "MATH",
        "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 1? "
        "Show step by step.",
    ),
    (
        "CODE",
        "Write a Python function that checks if a string is a palindrome. "
        "Include a docstring and edge cases.",
    ),
    (
        "LOGIC",
        "A bat and a ball cost $1.10 in total. "
        "The bat costs $1.00 more than the ball. "
        "How much does the ball cost? Explain.",
    ),
    (
        "ITA",
        "Spiega cos'è il machine learning in termini semplici, "
        "adatti a uno studente delle superiori.",
    ),
]


def main(argv: list[str] | None = None) -> None:
    """Load the model once, then generate a completion for every demo prompt.

    For each ``(tag, prompt)`` in ``PROMPTS`` the raw prompt is tokenized,
    passed to ``model.generate`` and the newly generated text is printed
    together with a token count and throughput figure.

    Args:
        argv: Arguments to parse instead of the real command line. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so a bare
            ``main()`` behaves exactly like running the script.

    Decoding: ``--temperature`` > 0 samples with temperature/top-p;
    ``--temperature`` <= 0 switches to greedy decoding (transformers rejects a
    non-positive temperature when sampling is enabled).
    """
    # Named `parser` (not `p`) so it is not confused with loop variables below.
    parser = argparse.ArgumentParser(
        description="Example inference for DANTE-Mosaic-3.5B.")
    parser.add_argument("--model", default="./",
                        help="HF repo id or local path to the model directory")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="sampling temperature; <= 0 selects greedy decoding")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="nucleus-sampling threshold (ignored when greedy)")
    parser.add_argument("--repetition-penalty", type=float, default=1.1)
    parser.add_argument("--seed", type=int, default=None,
                        help="seed torch's RNG for reproducible sampling")
    args = parser.parse_args(argv)
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be a positive integer")

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading {args.model} on {device} ...")
    # perf_counter is monotonic, unlike time.time(), so durations can't go negative.
    t0 = time.perf_counter()
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # NOTE: device_map="auto" relies on the `accelerate` package being installed.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Loaded in {time.perf_counter() - t0:.1f}s "
          f"({n_params / 1e9:.2f}B params)\n")

    do_sample = args.temperature > 0
    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": args.repetition_penalty,
        # Each call generates from a single unpadded prompt, so reusing EOS as
        # the pad id never actually pads anything; it just satisfies generate().
        "pad_token_id": tok.eos_token_id,
    }
    if do_sample:
        # Only pass sampling knobs when sampling, to avoid greedy-mode warnings.
        gen_kwargs.update(temperature=args.temperature, top_p=args.top_p)

    for tag, prompt in PROMPTS:
        print("─" * 60)
        print(f"[{tag}] {prompt}\n")
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        start = time.perf_counter()
        with torch.no_grad():
            out = model.generate(**inputs, **gen_kwargs)
        elapsed = time.perf_counter() - start
        # generate() returns prompt + continuation; count and decode only the new part.
        new_toks = out.shape[-1] - prompt_len
        text = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
        print(text)
        rate = new_toks / elapsed if elapsed > 0 else float("inf")
        print(f"\n  [{new_toks} tokens in {elapsed:.1f}s — {rate:.1f} tok/s]\n")


# Script entry point: run the demo, then exit through SystemExit; main()
# returns None, which yields the same exit status 0 as falling off the end.
if __name__ == "__main__":
    raise SystemExit(main())