Spaces:
Sleeping
Sleeping
| from inference.onnx_inference import generate_text | |
| import argparse | |
| import onnxruntime as ort | |
| from inference.model import ByteTokenizer | |
# Single-character strings that main() encodes to token IDs and passes to
# generate_text as DRY sequence breakers.
sequence_breaker_strings = [
    "\n",
    ":",
    '"',
    "*",
    "<",
    ">",
    "|",
]
def _build_arg_parser():
    """Create the argument parser for the ONNX inference command line."""
    parser = argparse.ArgumentParser(
        description="Inference with ONNX DiffTransformerLLM"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="models/small.onnx", help="Path to ONNX model"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
        help="Prompt for the model",
    )
    parser.add_argument("--max_tokens", type=int, default=100, help="Max new tokens")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument("--top_k", type=int, default=1, help="Top-k for sampling")
    parser.add_argument(
        "--stop_sequence",
        type=str,
        action="append",
        help="Stop sequence(s); repeat the flag to give several (default: <|im_end|>)",
    )
    # DRY sampling args
    parser.add_argument(
        "--dry_range", type=int, default=1024, help="Range for DRY sampling"
    )
    parser.add_argument(
        "--dry_allowed_length",
        type=int,
        default=17,
        help="Allowed repeat length for DRY sampling",
    )
    parser.add_argument(
        "--dry_base", type=float, default=1.1, help="Base for DRY penalty"
    )
    parser.add_argument(
        "--dry_multiplier", type=float, default=0.0, help="Multiplier for DRY penalty"
    )
    return parser


def main():
    """Run one text generation with an ONNX model from the command line.

    Parses the command-line options, opens an ONNX Runtime session on the
    CPU execution provider, builds the set of DRY sequence-breaker token IDs,
    calls ``generate_text`` and prints the result together with the
    throughput figure it returns.

    Stop sequences given with ``--stop_sequence`` are honoured; when the flag
    is absent, generation stops on ``<|im_end|>`` as before.
    """
    args = _build_arg_parser().parse_args()

    print(f"Loading ONNX model from {args.onnx_path}")
    session = ort.InferenceSession(args.onnx_path, providers=["CPUExecutionProvider"])
    tokenizer = ByteTokenizer()

    sequence_breaker_ids = {tokenizer.im_start_id, tokenizer.im_end_id}
    for s in sequence_breaker_strings:
        # These are single-byte tokens, so encode will return a list with one ID
        sequence_breaker_ids.add(tokenizer.encode(s.encode("utf-8"))[0])

    # argparse leaves an "append" option as None when the flag never appears,
    # so fall back to the default marker in that case. The sequences are
    # handed to generate_text as UTF-8 bytes, matching the previous
    # hard-coded call.
    stop_strings = args.stop_sequence or ["<|im_end|>"]
    stop_sequences = [stop.encode("utf-8") for stop in stop_strings]

    print(f"Prompt: {args.prompt}")
    print("--- Output ---")
    generated_text, tps = generate_text(
        session,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        stop_sequences=stop_sequences,
        dry_sequence_breakers=sequence_breaker_ids,
        dry_range=args.dry_range,
        dry_allowed_length=args.dry_allowed_length,
        dry_base=args.dry_base,
        dry_multiplier=args.dry_multiplier,
    )

    # Show the raw returned value first, then its UTF-8 decoding with
    # undecodable bytes dropped.
    print(generated_text)
    print(generated_text.decode("utf-8", "ignore"))
    print("--------------")
    print(f"\nPerformance: {tps:.2f} tokens/second")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()