import argparse
import os
import sys

import torch

from inference.inference import (
    force_CPU,
    generate_text_stream,
    list_checkpoints,
    load_model,
)
from inference.model import ByteTokenizer
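
# Command-line entry point for DiffAttention LLM inference. Three modes:
# single-shot text generation (--prompt), interactive chat (--chat), and
# checkpoint discovery (--list_checkpoints).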
def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # Generation mode arguments
    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Run in single-shot mode with the given prompt.",
    )
    parser.add_argument(
        "-c", "--chat", action="store_true", help="Run in interactive chat mode."
    )
    # Chat mode arguments
    parser.add_argument(
        "--system",
        type=str,
        default="You are a helpful chatbot.",
        help="System prompt for chat mode.",
    )
    parser.add_argument(
        "--user_role",
        type=str,
        default="user",
        help="Role name for the user in chat mode.",
    )
    parser.add_argument(
        "--assistant_role",
        type=str,
        default="assistant",
        help="Role name for the assistant in chat mode.",
    )
    # Common arguments
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="model.pt",
        help="Path to the checkpoint file.",
    )
    parser.add_argument(
        "--stop",
        nargs="+",
        default=[],
        help='One or more stop sequences, e.g. --stop "world".',
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.35, help="Sampling temperature."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=7,
        help="Top-k sampling parameter (0 to disable).",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.35,
        help="Repetition penalty (1.0 for no penalty).",
    )
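    # The sampling knobs above are forwarded unchanged to generate_text_stream;
    # how temperature, top-k, and the repetition penalty interact is defined in
    # inference.inference, not here.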
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit.",
    )
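    # Example invocations (assuming this script is saved as generate.py):
    #   python generate.py --prompt "Once upon a time" --max_tokens 256
    #   python generate.py --chat --system "You are a helpful chatbot."
    #   python generate.py --list_checkpoints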
    args = parser.parse_args()
    if not args.prompt and not args.chat and not args.list_checkpoints:
        parser.print_help()
        sys.exit(
            "\nError: One of --prompt, --chat, or --list_checkpoints must be specified."
        )
    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found.")
        for i, ckpt in enumerate(checkpoints):
            print(f"{i + 1}. {ckpt}")
        return
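    # Fall back to the newest checkpoint when the requested file is missing.
    # "Newest" here is the lexicographically greatest filename (max()), with
    # names containing "end.pt" preferred; whether those actually mark the end
    # of a training run depends on how training names its checkpoints.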
    checkpoint_path = args.checkpoint
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Searching for latest checkpoint in 'checkpoints/' directory...")
        checkpoints = list_checkpoints()
        if not checkpoints:
            sys.exit(
                "No checkpoints found. Please train a model or specify a valid path."
            )
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
        print(f"Using latest checkpoint: {checkpoint_path}")
    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")
    tokenizer = ByteTokenizer()
    # Load model
    model = load_model(checkpoint_path, device)
    # --- Mode Handling ---
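    # Chat mode wraps every turn in ChatML-style markers
    # (<|im_start|>{role}\n ... <|im_end|>) and adds <|im_end|> as an implicit
    # stop sequence so generation halts at the end of the assistant turn.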
    if args.chat:
        stop_sequences = args.stop + ["<|im_end|>"]
        history = f"<|im_start|>system\n{args.system}<|im_end|>\n"
        print("\n--- Interactive Chat ---")
        print(f"System Prompt: {args.system}")
        print("Type 'exit' or 'quit' to end the session.")
        print("-" * 26)
        while True:
            try:
                user_prompt_display = f"<|im_start|>{args.user_role}\n"
                user_input = input(user_prompt_display)
                if user_input.lower() in ["exit", "quit"]:
                    break
                prompt = (
                    history
                    + f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                    + f"<|im_start|>{args.assistant_role}\n"
                )
                print(f"<|im_start|>{args.assistant_role}")
                sys.stdout.flush()
                generated_text_parts = []
                for chunk in generate_text_stream(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompt,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    repetition_penalty=args.repetition_penalty,
                    device=device,
                    stop_sequences=stop_sequences,
                ):
                    print(chunk, end="", flush=True)
                    generated_text_parts.append(chunk)
                generated_text = "".join(generated_text_parts)
                # Record the full exchange; the entire history is re-sent as
                # the prompt on every turn, so context grows with the session.
                history += (
                    f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                    + f"<|im_start|>{args.assistant_role}\n{generated_text}<|im_end|>\n"
                )
                print()  # Newline after assistant output
            except (KeyboardInterrupt, EOFError):
                print("\nExiting chat.")
                break
    else:
        print(f"\nGenerating text with prompt: '{args.prompt}'")
        print(
            f"Parameters: temp={args.temperature}, top_k={args.top_k}, repetition_penalty={args.repetition_penalty}"
        )
        print("\n--- Generation Start ---")
        generated_text_parts = []
        for chunk in generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            device=device,
            stop_sequences=args.stop,
        ):
            print(chunk, end="", flush=True)
            generated_text_parts.append(chunk)
        print("\n--- Generation End ---")
        generated_text = "".join(generated_text_parts)
        full_text = args.prompt + generated_text
        print("\n\nFull generated text (for reference):")
        print("-" * 40)
        print(full_text)
        print("-" * 40)


if __name__ == "__main__":
    main()