File size: 4,414 Bytes
c886682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3

import argparse
import json
import os
import random
import sys


def set_seed(seed: int):
    """Seed every available RNG source for reproducible generation.

    Always seeds Python's ``random``; additionally seeds NumPy and
    PyTorch (CPU and all CUDA devices) when those packages are
    installed.  Missing optional packages are skipped silently.
    """
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        pass  # NumPy not installed — nothing to seed
    else:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        pass  # PyTorch not installed — nothing to seed
    else:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


def main():
    """Run generation smoke tests against a converted RNNLM model.

    Loads the model/tokenizer from ``--model_path``, reads a JSON list of
    prompt strings, and runs six generation configurations (defaults,
    short output, sentence limits, greedy decoding, low temperature),
    printing each prompt/generation pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", "-m", default=".", help="Path to converted model")
    # NOTE: default is None so the script-relative fallback below is
    # actually reachable; the previous cwd-relative default
    # "test_prompts.json" made that branch dead code and contradicted
    # the help text.
    parser.add_argument(
        "--prompts", "-p", default=None,
        help="Path to JSON file with list of prompt strings "
             "(default: test_prompts.json next to this script)")
    parser.add_argument(
        "--seed", "-s", type=int, default=0,
        help="Random seed for reproducible generation (default: 0)")
    parser.add_argument(
        "--max_new_tokens", type=int, default=None,
        help="Max tokens to generate (default: 50)")
    parser.add_argument(
        "--max_new_sents", type=int, default=None,
        help="Max sentences in decoded output (default: pipeline default)")
    args = parser.parse_args()

    # Seed before any model loading so weight init / sampling is reproducible.
    if args.seed is not None:
        set_seed(args.seed)
        print(f"Random seed set to {args.seed} for reproducibility")

    if not os.path.isdir(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist.")
        sys.exit(1)

    # Resolve the prompts file: explicit --prompts wins, otherwise use the
    # test_prompts.json shipped alongside this script.
    prompts_path = args.prompts
    if prompts_path is None:
        prompts_path = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), "test_prompts.json")
    if not os.path.isfile(prompts_path):
        print(f"Error: Prompts file {prompts_path} does not exist.")
        sys.exit(1)

    print("Loading model and tokenizer...")
    from transformers import AutoModelForCausalLM

    # Register custom model and load tokenizer directly (AutoTokenizer doesn't know RNNLMTokenizer)
    model_path = os.path.abspath(args.model_path)
    from rnnlm_model import (
        RNNLMConfig,
        RNNLMForCausalLM,
        RNNLMTokenizer,
        RNNLMTextGenerationPipeline,
    )
    from transformers import AutoConfig
    AutoConfig.register("rnnlm", RNNLMConfig)
    AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True)
    tokenizer = RNNLMTokenizer.from_pretrained(model_path)

    print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
    pipe = RNNLMTextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
    )

    with open(prompts_path) as f:
        test_prompts = json.load(f)

    # Baseline kwargs shared by every test; individual tests override below.
    base_kwargs = dict(
        max_new_tokens=args.max_new_tokens if args.max_new_tokens is not None else 50,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
    if args.max_new_sents is not None:
        base_kwargs["max_new_sents"] = args.max_new_sents

    def run_tests(kwargs):
        # Generate once per prompt with the given kwargs and print results.
        for i, prompt in enumerate(test_prompts):
            print(f"\n  [{i + 1}/{len(test_prompts)}]")
            print(f"  PROMPT: ``{prompt}``")
            output = pipe(prompt, **kwargs)
            print(f"  GENERATED: ``{output[0]['generated_text']}``")

    # Test 1: Basic generation with default params
    print("\n--- Test 1: Basic generation (default params) ---")
    run_tests(base_kwargs)

    # Test 2: max_new_tokens=20
    print("\n--- Test 2: max_new_tokens=20 ---")
    short_kwargs = {**base_kwargs, "max_new_tokens": 20}
    run_tests(short_kwargs)

    # Test 3: max_new_sents=2
    print("\n--- Test 3: max_new_sents=2 ---")
    sents_kwargs = {**base_kwargs, "max_new_sents": 2}
    run_tests(sents_kwargs)

    # Test 4: max_new_sents=1
    print("\n--- Test 4: max_new_sents=1 ---")
    sents1_kwargs = {**base_kwargs, "max_new_sents": 1}
    run_tests(sents1_kwargs)

    # Test 5: do_sample=False (greedy decoding)
    print("\n--- Test 5: do_sample=False ---")
    greedy_kwargs = {**base_kwargs, "do_sample": False}
    run_tests(greedy_kwargs)

    # Test 6: temperature=0.3
    print("\n--- Test 6: temperature=0.3 ---")
    low_temp_kwargs = {**base_kwargs, "temperature": 0.3}
    run_tests(low_temp_kwargs)


if __name__ == "__main__":
    main()