#!/usr/bin/env python3
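"""Smoke-test generation for a converted RNNLM Hugging Face model.

Loads the converted model, its custom tokenizer and text-generation pipeline,
then runs a set of test prompts under several generation settings
(token limits, sentence limits, greedy vs. sampled decoding, temperature).

Example invocation (script and path names are illustrative; adjust to your layout):
    python test_generation.py --model_path ./converted_model --prompts test_prompts.json --seed 0
"""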
import argparse
import json
import os
import random
import sys


def set_seed(seed: int):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", "-m", default=".", help="Path to converted model")
    parser.add_argument(
        "--prompts", "-p", default=None,
        help="Path to JSON file with list of prompt strings (default: hf_conversion/test_prompts.json)")
    parser.add_argument(
        "--seed", "-s", type=int, default=None,
        help="Random seed for reproducible generation (default: None, non-deterministic)")
    parser.add_argument(
        "--max_new_tokens", type=int, default=None,
        help="Max tokens to generate (default: 50)")
    parser.add_argument(
        "--max_new_sents", type=int, default=None,
        help="Max sentences in decoded output (default: pipeline default)")
    args = parser.parse_args()

    if args.seed is not None:
        set_seed(args.seed)
        print(f"Random seed set to {args.seed} for reproducibility")

    if not os.path.isdir(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist or is not a directory.")
        sys.exit(1)
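
    # Resolve the prompts file: an explicit --prompts path wins; otherwise fall back
    # to test_prompts.json next to this script. The file is expected to contain a JSON
    # list of prompt strings, e.g. (illustrative only):
    #   ["Once upon a time", "The committee met on Tuesday to"]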
    prompts_path = args.prompts
    if prompts_path is None:
        prompts_path = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), "test_prompts.json")
    if not os.path.isfile(prompts_path):
        print(f"Error: Prompts file {prompts_path} does not exist.")
        sys.exit(1)
print("Loading model and tokenizer...")
from transformers import AutoModelForCausalLM
# Register custom model and load tokenizer directly (AutoTokenizer doesn't know RNNLMTokenizer)
model_path = os.path.abspath(args.model_path)
from rnnlm_model import (
RNNLMConfig,
RNNLMForCausalLM,
RNNLMTokenizer,
RNNLMTextGenerationPipeline,
)
from transformers import AutoConfig
AutoConfig.register("rnnlm", RNNLMConfig)
AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)
model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True)
tokenizer = RNNLMTokenizer.from_pretrained(model_path)
print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
pipe = RNNLMTextGenerationPipeline(
model=model,
tokenizer=tokenizer,
)
with open(prompts_path) as f:
test_prompts = json.load(f)
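
    # Baseline generation settings shared by every test below; each test overrides
    # a single knob (length, sentence cap, sampling, or temperature).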
    base_kwargs = dict(
        max_new_tokens=args.max_new_tokens if args.max_new_tokens is not None else 50,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
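    # max_new_sents appears to be a pipeline-level option (not a standard generate()
    # argument), so only forward it when the user asked for it.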
    if args.max_new_sents is not None:
        base_kwargs["max_new_sents"] = args.max_new_sents
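
    # Helper: run every prompt through the pipeline with the given generation
    # kwargs and print each prompt/output pair.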
    def run_tests(kwargs):
        for i, prompt in enumerate(test_prompts):
            print(f"\n [{i + 1}/{len(test_prompts)}]")
            print(f" PROMPT: ``{prompt}``")
            output = pipe(prompt, **kwargs)
            print(f" GENERATED: ``{output[0]['generated_text']}``")

    # Test 1: Basic generation with default params
    print("\n--- Test 1: Basic generation (default params) ---")
    run_tests(base_kwargs)

    # Test 2: max_new_tokens=20
    print("\n--- Test 2: max_new_tokens=20 ---")
    short_kwargs = {**base_kwargs, "max_new_tokens": 20}
    run_tests(short_kwargs)

    # Test 3: max_new_sents=2
    print("\n--- Test 3: max_new_sents=2 ---")
    sents_kwargs = {**base_kwargs, "max_new_sents": 2}
    run_tests(sents_kwargs)

    # Test 4: max_new_sents=1
    print("\n--- Test 4: max_new_sents=1 ---")
    sents1_kwargs = {**base_kwargs, "max_new_sents": 1}
    run_tests(sents1_kwargs)

    # Test 5: do_sample=False (greedy decoding)
    print("\n--- Test 5: do_sample=False ---")
    greedy_kwargs = {**base_kwargs, "do_sample": False}
    run_tests(greedy_kwargs)

    # Test 6: temperature=0.3
    print("\n--- Test 6: temperature=0.3 ---")
    low_temp_kwargs = {**base_kwargs, "temperature": 0.3}
    run_tests(low_temp_kwargs)


if __name__ == "__main__":
    main()