Update generate_text.py
Browse files- generate_text.py +8 -12
generate_text.py
CHANGED
|
@@ -28,9 +28,8 @@ from model import LunaConfig, Luna
|
|
| 28 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
|
| 30 |
|
| 31 |
-
# ==============================================================================
|
| 32 |
# Tokenizer
|
| 33 |
-
|
| 34 |
|
| 35 |
class LunaTokenizer:
|
| 36 |
"""Tokenizer for Luna."""
|
|
@@ -199,9 +198,8 @@ class LunaTokenizer:
|
|
| 199 |
return syl
|
| 200 |
|
| 201 |
|
| 202 |
-
|
| 203 |
# Helpers
|
| 204 |
-
# ==============================================================================
|
| 205 |
|
| 206 |
def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
|
| 207 |
feature_names = [
|
|
@@ -259,9 +257,9 @@ def decode_tokens(tokenizer: LunaTokenizer, tokens: List[Dict]) -> str:
|
|
| 259 |
return result
|
| 260 |
|
| 261 |
|
| 262 |
-
|
| 263 |
# Model Loading
|
| 264 |
-
|
| 265 |
|
| 266 |
def load_model(checkpoint_path: str, data_dir: str):
|
| 267 |
vocab_path = os.path.join(data_dir, "vocab.json")
|
|
@@ -294,9 +292,8 @@ def load_model(checkpoint_path: str, data_dir: str):
|
|
| 294 |
return model, tokenizer, checkpoint.get('val_loss', 0)
|
| 295 |
|
| 296 |
|
| 297 |
-
# ==============================================================================
|
| 298 |
# Generation
|
| 299 |
-
|
| 300 |
|
| 301 |
@torch.no_grad()
|
| 302 |
def generate(
|
|
@@ -417,9 +414,9 @@ def generate(
|
|
| 417 |
return prompt_text + generated_text
|
| 418 |
|
| 419 |
|
| 420 |
-
|
| 421 |
# Interactive Mode
|
| 422 |
-
|
| 423 |
|
| 424 |
def interactive_mode(model, tokenizer, args):
|
| 425 |
print("\n" + "=" * 60)
|
|
@@ -451,9 +448,8 @@ def interactive_mode(model, tokenizer, args):
|
|
| 451 |
print("\nGoodbye!")
|
| 452 |
|
| 453 |
|
| 454 |
-
# ==============================================================================
|
| 455 |
# Main
|
| 456 |
-
|
| 457 |
|
| 458 |
def main():
|
| 459 |
parser = argparse.ArgumentParser(description="Generate text with Luna")
|
|
|
|
| 28 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
|
| 30 |
|
|
|
|
| 31 |
# Tokenizer
|
| 32 |
+
|
| 33 |
|
| 34 |
class LunaTokenizer:
|
| 35 |
"""Tokenizer for Luna."""
|
|
|
|
| 198 |
return syl
|
| 199 |
|
| 200 |
|
| 201 |
+
|
| 202 |
# Helpers
|
|
|
|
| 203 |
|
| 204 |
def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
|
| 205 |
feature_names = [
|
|
|
|
| 257 |
return result
|
| 258 |
|
| 259 |
|
| 260 |
+
|
| 261 |
# Model Loading
|
| 262 |
+
|
| 263 |
|
| 264 |
def load_model(checkpoint_path: str, data_dir: str):
|
| 265 |
vocab_path = os.path.join(data_dir, "vocab.json")
|
|
|
|
| 292 |
return model, tokenizer, checkpoint.get('val_loss', 0)
|
| 293 |
|
| 294 |
|
|
|
|
| 295 |
# Generation
|
| 296 |
+
|
| 297 |
|
| 298 |
@torch.no_grad()
|
| 299 |
def generate(
|
|
|
|
| 414 |
return prompt_text + generated_text
|
| 415 |
|
| 416 |
|
| 417 |
+
|
| 418 |
# Interactive Mode
|
| 419 |
+
|
| 420 |
|
| 421 |
def interactive_mode(model, tokenizer, args):
|
| 422 |
print("\n" + "=" * 60)
|
|
|
|
| 448 |
print("\nGoodbye!")
|
| 449 |
|
| 450 |
|
|
|
|
| 451 |
# Main
|
| 452 |
+
|
| 453 |
|
| 454 |
def main():
|
| 455 |
parser = argparse.ArgumentParser(description="Generate text with Luna")
|