JMSykala commited on
Commit
d19cc7f
·
verified ·
1 Parent(s): 9f1dbe9

Update generate_text.py

Browse files
Files changed (1) hide show
  1. generate_text.py +8 -12
generate_text.py CHANGED
@@ -28,9 +28,8 @@ from model import LunaConfig, Luna
28
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
 
31
- # ==============================================================================
32
  # Tokenizer
33
- # ==============================================================================
34
 
35
  class LunaTokenizer:
36
  """Tokenizer for Luna."""
@@ -199,9 +198,8 @@ class LunaTokenizer:
199
  return syl
200
 
201
 
202
- # ==============================================================================
203
  # Helpers
204
- # ==============================================================================
205
 
206
  def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
207
  feature_names = [
@@ -259,9 +257,9 @@ def decode_tokens(tokenizer: LunaTokenizer, tokens: List[Dict]) -> str:
259
  return result
260
 
261
 
262
- # ==============================================================================
263
  # Model Loading
264
- # ==============================================================================
265
 
266
  def load_model(checkpoint_path: str, data_dir: str):
267
  vocab_path = os.path.join(data_dir, "vocab.json")
@@ -294,9 +292,8 @@ def load_model(checkpoint_path: str, data_dir: str):
294
  return model, tokenizer, checkpoint.get('val_loss', 0)
295
 
296
 
297
- # ==============================================================================
298
  # Generation
299
- # ==============================================================================
300
 
301
  @torch.no_grad()
302
  def generate(
@@ -417,9 +414,9 @@ def generate(
417
  return prompt_text + generated_text
418
 
419
 
420
- # ==============================================================================
421
  # Interactive Mode
422
- # ==============================================================================
423
 
424
  def interactive_mode(model, tokenizer, args):
425
  print("\n" + "=" * 60)
@@ -451,9 +448,8 @@ def interactive_mode(model, tokenizer, args):
451
  print("\nGoodbye!")
452
 
453
 
454
- # ==============================================================================
455
  # Main
456
- # ==============================================================================
457
 
458
  def main():
459
  parser = argparse.ArgumentParser(description="Generate text with Luna")
 
28
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
 
 
31
  # Tokenizer
32
+
33
 
34
  class LunaTokenizer:
35
  """Tokenizer for Luna."""
 
198
  return syl
199
 
200
 
201
+
202
  # Helpers
 
203
 
204
  def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
205
  feature_names = [
 
257
  return result
258
 
259
 
260
+
261
  # Model Loading
262
+
263
 
264
  def load_model(checkpoint_path: str, data_dir: str):
265
  vocab_path = os.path.join(data_dir, "vocab.json")
 
292
  return model, tokenizer, checkpoint.get('val_loss', 0)
293
 
294
 
 
295
  # Generation
296
+
297
 
298
  @torch.no_grad()
299
  def generate(
 
414
  return prompt_text + generated_text
415
 
416
 
417
+
418
  # Interactive Mode
419
+
420
 
421
  def interactive_mode(model, tokenizer, args):
422
  print("\n" + "=" * 60)
 
448
  print("\nGoodbye!")
449
 
450
 
 
451
  # Main
452
+
453
 
454
  def main():
455
  parser = argparse.ArgumentParser(description="Generate text with Luna")