Text Generation
Safetensors
English
hudsongouge committed on
Commit
cb2d3e5
·
verified ·
1 Parent(s): b8d41ce

Upload run.py

Browse files
Files changed (1) hide show
  1. run.py +226 -0
run.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference.inference import (
2
+ force_CPU,
3
+ generate_text_stream,
4
+ list_checkpoints,
5
+ load_model,
6
+ )
7
+ import argparse
8
+ import torch
9
+ from inference.model import ByteTokenizer
10
+ import os
11
+ import sys
12
+
13
+
14
def _build_parser():
    """Build the argument parser for all generation and chat options.

    Returns:
        argparse.ArgumentParser: parser with generation-mode, chat-mode, and
        common sampling/checkpoint arguments registered.
    """
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # Generation mode arguments
    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Run in single-shot mode with the given prompt.",
    )
    parser.add_argument(
        "-c", "--chat", action="store_true", help="Run in interactive chat mode."
    )

    # Chat mode arguments
    parser.add_argument(
        "--system",
        type=str,
        default="You are a helpful chatbot.",
        help="System prompt for chat mode.",
    )
    parser.add_argument(
        "--user_role",
        type=str,
        default="user",
        help="Role name for the user in chat mode.",
    )
    parser.add_argument(
        "--assistant_role",
        type=str,
        default="assistant",
        help="Role name for the assistant in chat mode.",
    )

    # Common arguments
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="model.pt",
        help="Path to the checkpoint file.",
    )
    parser.add_argument(
        "--stop",
        nargs="+",
        default=[],
        help='One or more stop sequences. e.g. --stop "world" """',
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.35, help="Sampling temperature."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=7,
        help="Top-k sampling parameter (0 to disable).",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.35,
        help="Repetition penalty (1.0 for no penalty).",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit.",
    )
    return parser


def _resolve_checkpoint(checkpoint_path):
    """Return a usable checkpoint path.

    If ``checkpoint_path`` exists it is returned unchanged; otherwise the
    'checkpoints/' directory is searched via ``list_checkpoints()`` and the
    latest checkpoint is used. Exits the process when nothing is found.
    """
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    print(f"Checkpoint file not found: {checkpoint_path}")
    print("Searching for latest checkpoint in 'checkpoints/' directory...")
    checkpoints = list_checkpoints()
    if not checkpoints:
        sys.exit(
            "No checkpoints found. Please train a model or specify a valid path."
        )

    # Prefer end-of-training checkpoints ("end.pt"); "latest" is the
    # lexicographic max, which assumes sortable checkpoint filenames —
    # NOTE(review): confirm the naming scheme sorts chronologically.
    end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
    if end_checkpoints:
        latest_checkpoint = max(end_checkpoints)
    else:
        latest_checkpoint = max(checkpoints)

    checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    print(f"Using latest checkpoint: {checkpoint_path}")
    return checkpoint_path


def _select_device():
    """Pick the torch device: MPS, then CUDA, then CPU.

    ``force_CPU`` (imported from the inference module) overrides any
    accelerator and pins execution to the CPU.
    """
    if torch.backends.mps.is_available() and not force_CPU:
        return torch.device("mps")
    return torch.device(
        "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
    )


def _run_chat(model, tokenizer, device, args):
    """Interactive chat loop using <|im_start|>/<|im_end|> chat markup.

    The full conversation history is re-fed as the prompt on every turn.
    The loop exits on 'exit'/'quit', Ctrl-C, or EOF.
    """
    stop_sequences = args.stop + ["<|im_end|>"]
    history = f"<|im_start|>system\n{args.system}<|im_end|>\n"
    print("\n--- Interactive Chat ---")
    print(f"System Prompt: {args.system}")
    print("Type 'exit' or 'quit' to end the session.")
    print("-" * 26)

    while True:
        try:
            user_prompt_display = f"<|im_start|>{args.user_role}\n"
            user_input = input(user_prompt_display)

            if user_input.lower() in ["exit", "quit"]:
                break

            # NOTE(review): the assistant turn is primed with "```" in the
            # prompt, but the backticks are neither echoed nor stored in
            # history — presumably deliberate priming; verify against the
            # model's training format.
            prompt = (
                history
                + f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                + f"<|im_start|>{args.assistant_role}\n```"
            )

            print(f"<|im_start|>{args.assistant_role}")
            sys.stdout.flush()

            generated_text_parts = []
            for chunk in generate_text_stream(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                repetition_penalty=args.repetition_penalty,
                device=device,
                stop_sequences=stop_sequences,
            ):
                print(chunk, end="", flush=True)
                generated_text_parts.append(chunk)

            generated_text = "".join(generated_text_parts)

            history += (
                f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                + f"<|im_start|>{args.assistant_role}\n{generated_text}<|im_end|>\n"
            )
            print()  # Newline after assistant output

        except (KeyboardInterrupt, EOFError):
            print("\nExiting chat.")
            break


def _run_single(model, tokenizer, device, args):
    """Single-shot generation: stream the completion for ``args.prompt``,
    then echo the full prompt + completion for reference."""
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temp={args.temperature}, top_k={args.top_k}, repetition_penalty={args.repetition_penalty}"
    )
    print("\n--- Generation Start ---")

    generated_text_parts = []
    for chunk in generate_text_stream(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        device=device,
        stop_sequences=args.stop,
    ):
        print(chunk, end="", flush=True)
        generated_text_parts.append(chunk)

    print("\n--- Generation End ---")

    generated_text = "".join(generated_text_parts)
    full_text = args.prompt + generated_text

    print("\n\nFull generated text (for reference):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


def main():
    """CLI entry point.

    Parses arguments, optionally lists checkpoints, resolves the checkpoint
    to load, selects a device, loads the model, and dispatches to either the
    interactive chat loop or single-shot generation.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Require at least one mode flag; otherwise show help and bail out.
    if not args.prompt and not args.chat and not args.list_checkpoints:
        parser.print_help()
        sys.exit(
            "\nError: Either --prompt, --chat, or --list_checkpoints must be specified."
        )

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found.")
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    checkpoint_path = _resolve_checkpoint(args.checkpoint)

    device = _select_device()
    print(f"Using device: {device}")

    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # --- Mode Handling ---
    if args.chat:
        _run_chat(model, tokenizer, device, args)
    else:
        _run_single(model, tokenizer, device, args)
223
+
224
+
225
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()