File size: 1,430 Bytes
2147ce8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Absolute path of the repository root: two levels above this script file
# (the script's own directory is parents[0]).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of the import path (only once) so the
# in-repo ``reframr`` package imported below can be found from a checkout.
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))

from reframr.model import ReframrModel


def build_context(prompt: str, system: str = "") -> str:
    """Build the text handed to the model for one generation request.

    Args:
        prompt: The user prompt, used verbatim.
        system: Optional system instruction. Leading/trailing whitespace is
            stripped; if nothing remains, no system section is added.

    Returns:
        ``prompt`` unchanged when ``system`` is blank; otherwise a
        ``System instruction: ...`` line followed by a ``User: ...`` line.
    """
    instruction = system.strip()
    if not instruction:
        # NOTE(review): without a system instruction the prompt is sent with
        # no "User:" label, unlike the branch below. Kept as-is because the
        # prompt format ReframrModel expects is not visible here; confirm the
        # asymmetry is intended.
        return prompt
    return f"System instruction: {instruction}\nUser: {prompt}"


def generation_arg_errors(args: argparse.Namespace) -> list[str]:
    """Return one message for every sampling option that is out of range.

    Only clearly meaningless values are rejected: fewer than one token,
    negative temperature, negative top-k, top-p outside [0, 1], or a
    non-positive repetition penalty. Each check is written as
    ``not (value in range)`` so that NaN is rejected too. How boundary values
    such as ``temperature=0`` or ``top_k=0`` are interpreted is up to
    ``ReframrModel.generate_text``. NOTE(review): confirm they are accepted
    there.

    Args:
        args: Parsed CLI namespace with ``max_tokens``, ``temperature``,
            ``top_k``, ``top_p`` and ``repetition_penalty`` attributes.

    Returns:
        Human-readable error messages; an empty list when every value is
        acceptable.
    """
    errors: list[str] = []
    if not args.max_tokens >= 1:
        errors.append(f"--max-tokens must be >= 1 (got {args.max_tokens})")
    if not args.temperature >= 0:
        errors.append(f"--temperature must be >= 0 (got {args.temperature})")
    if not args.top_k >= 0:
        errors.append(f"--top-k must be >= 0 (got {args.top_k})")
    if not 0 <= args.top_p <= 1:
        errors.append(f"--top-p must be between 0 and 1 (got {args.top_p})")
    if not args.repetition_penalty > 0:
        errors.append(
            f"--repetition-penalty must be > 0 (got {args.repetition_penalty})"
        )
    return errors


def main() -> None:
    """Parse CLI options, load the model, and print one generated completion.

    Out-of-range sampling options or a non-existent model path are reported
    via ``parser.error`` before the model is loaded. That call prints the
    usage and the message to stderr and exits with status 2.
    """
    parser = argparse.ArgumentParser(description="Run Reframr-RFM-v1-Base locally.")
    parser.add_argument(
        "--model",
        default=str(REPO_ROOT / "model.safetensors"),
        help="path to the model file (default: %(default)s)",
    )
    parser.add_argument(
        "--prompt",
        default="Who are you, and what makes Reframr different?",
        help="user prompt to complete",
    )
    parser.add_argument(
        "--system",
        default="",
        help="optional system instruction placed before the prompt",
    )
    parser.add_argument(
        "--max-tokens", type=int, default=90, help="number of tokens to generate (limit)"
    )
    parser.add_argument("--temperature", type=float, default=0.92, help="sampling temperature")
    parser.add_argument("--top-k", type=int, default=72, help="top-k sampling cutoff")
    parser.add_argument("--top-p", type=float, default=0.92, help="top-p (nucleus) sampling threshold")
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.18,
        help="penalty applied to repeated tokens",
    )
    args = parser.parse_args()

    errors = generation_arg_errors(args)
    if errors:
        parser.error("; ".join(errors))
    # Fail fast with a readable CLI error rather than whatever
    # ReframrModel.load raises for a missing path.
    if not Path(args.model).exists():
        parser.error(f"model file not found: {args.model}")

    model = ReframrModel.load(args.model)
    completion = model.generate_text(
        build_context(args.prompt, args.system),
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )
    print(completion)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()