| | |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import random |
| | import sys |
| |
|
| |
|
def set_seed(seed: int):
    """Seed every available RNG (stdlib, NumPy, PyTorch) for reproducible runs.

    NumPy and PyTorch are optional: if either is not installed, its seeding
    step is simply skipped.
    """
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed every CUDA device, not just the current one.
            torch.cuda.manual_seed_all(seed)
| |
|
| |
|
def main():
    """Smoke-test a converted RNNLM checkpoint by generating from a prompt set.

    Loads the model/tokenizer from --model_path, reads a JSON list of prompt
    strings, then runs the same prompts through six generation configurations
    (defaults, short output, sentence caps, greedy decoding, low temperature).
    Exits with status 1 if the model directory or prompts file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", "-m", default=".", help="Path to converted model")
    # default=None so the script-relative fallback below is reachable; the old
    # literal default made that branch dead code and resolved the file against
    # the current working directory instead of the script's directory.
    parser.add_argument(
        "--prompts", "-p", default=None,
        help="Path to JSON file with list of prompt strings "
             "(default: test_prompts.json next to this script)")
    # default=None to match the documented behavior; with the previous
    # default=0 every run was seeded and generation was never
    # non-deterministic.
    parser.add_argument(
        "--seed", "-s", type=int, default=None,
        help="Random seed for reproducible generation (default: None, non-deterministic)")
    # Fold the documented default (50) into argparse instead of patching the
    # value up later when building base_kwargs.
    parser.add_argument(
        "--max_new_tokens", type=int, default=50,
        help="Max tokens to generate (default: 50)")
    parser.add_argument(
        "--max_new_sents", type=int, default=None,
        help="Max sentences in decoded output (default: pipeline default)")
    args = parser.parse_args()

    if args.seed is not None:
        set_seed(args.seed)
        print(f"Random seed set to {args.seed} for reproducibility")

    if not os.path.isdir(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist.")
        sys.exit(1)

    # Fall back to the prompts file shipped alongside this script.
    prompts_path = args.prompts
    if prompts_path is None:
        prompts_path = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), "test_prompts.json")
    if not os.path.isfile(prompts_path):
        print(f"Error: Prompts file {prompts_path} does not exist.")
        sys.exit(1)

    print("Loading model and tokenizer...")
    # Heavy imports are deferred so --help and argument errors stay fast.
    from transformers import AutoConfig, AutoModelForCausalLM

    model_path = os.path.abspath(args.model_path)
    from rnnlm_model import (
        RNNLMConfig,
        RNNLMForCausalLM,
        RNNLMTokenizer,
        RNNLMTextGenerationPipeline,
    )
    # Register the custom architecture so the Auto* classes resolve "rnnlm".
    AutoConfig.register("rnnlm", RNNLMConfig)
    AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True)
    tokenizer = RNNLMTokenizer.from_pretrained(model_path)

    print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
    pipe = RNNLMTextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
    )

    with open(prompts_path) as f:
        test_prompts = json.load(f)

    base_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
    if args.max_new_sents is not None:
        base_kwargs["max_new_sents"] = args.max_new_sents

    def run_tests(kwargs):
        # Generate from every prompt with the given generation kwargs and
        # echo prompt/output pairs for manual inspection.
        for i, prompt in enumerate(test_prompts):
            print(f"\n [{i + 1}/{len(test_prompts)}]")
            print(f" PROMPT: ``{prompt}``")
            output = pipe(prompt, **kwargs)
            print(f" GENERATED: ``{output[0]['generated_text']}``")

    print("\n--- Test 1: Basic generation (default params) ---")
    run_tests(base_kwargs)

    print("\n--- Test 2: max_new_tokens=20 ---")
    short_kwargs = {**base_kwargs, "max_new_tokens": 20}
    run_tests(short_kwargs)

    print("\n--- Test 3: max_new_sents=2 ---")
    sents_kwargs = {**base_kwargs, "max_new_sents": 2}
    run_tests(sents_kwargs)

    print("\n--- Test 4: max_new_sents=1 ---")
    sents1_kwargs = {**base_kwargs, "max_new_sents": 1}
    run_tests(sents1_kwargs)

    print("\n--- Test 5: do_sample=False ---")
    greedy_kwargs = {**base_kwargs, "do_sample": False}
    run_tests(greedy_kwargs)

    print("\n--- Test 6: temperature=0.3 ---")
    low_temp_kwargs = {**base_kwargs, "temperature": 0.3}
    run_tests(low_temp_kwargs)
| |
|
| |
|
# Script entry point: run the smoke tests only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
| |
|