# Luna-150M — generate_text.py
# Copyright 2026 Jakub Sykała
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import argparse
import re
from typing import Dict, List
from collections import OrderedDict
import torch
import torch.nn.functional as F
import pyphen
from model import LunaConfig, Luna
# Run on GPU when available, otherwise CPU; all tensors/models use this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer
# Tokenizer
class LunaTokenizer:
    """Syllable-level tokenizer for Luna.

    Words are hyphenated into syllables with pyphen, and each syllable is
    further decomposed into onset / nucleus / coda pieces mapped to vocabulary
    ids.  Digits and punctuation are encoded one character at a time through
    special vocabulary entries such as ``<num_3>`` and ``<punct_.>``.
    """

    # Characters treated as syllable nuclei when splitting onset/nucleus/coda.
    VOWELS = set('aeiouyAEIOUY')

    # token_type codes carried in every token dict.
    TYPE_SYLLABLE = 0
    TYPE_NUMBER = 1
    TYPE_PUNCT = 2
    TYPE_SPECIAL = 3

    def __init__(self):
        self.hyphenator = pyphen.Pyphen(lang='en_US')
        # Vocabulary maps; populated by load_vocab().
        self.syllable_to_id = {}
        self.id_to_syllable = {}
        self.onset_to_id = {}
        self.nucleus_to_id = {}
        self.coda_to_id = {}
        # Fallback ids for out-of-vocabulary items; load_vocab() overwrites
        # them with the vocab's own <unk> entries when present.
        self.unk_syllable = 1
        self.unk_onset = 1
        self.unk_nucleus = 1
        self.unk_coda = 1

    def load_vocab(self, vocab_path: str):
        """Load the vocabulary JSON written at training time.

        Rebuilds ``id_to_syllable`` from ``syllable_to_id`` when the inverse
        map is missing, and caches the <unk> fallback ids.
        """
        # BUG FIX: read as UTF-8 explicitly — syllables may be non-ASCII and
        # the platform default encoding is not guaranteed (e.g. cp1252 on
        # Windows would fail to decode the vocab file).
        with open(vocab_path, encoding='utf-8') as f:
            vocab = json.load(f)
        self.syllable_to_id = vocab.get('syllable_to_id', {})
        self.id_to_syllable = {int(k): v for k, v in vocab.get('id_to_syllable', {}).items()}
        self.onset_to_id = vocab.get('onset_to_id', {})
        self.nucleus_to_id = vocab.get('nucleus_to_id', {})
        self.coda_to_id = vocab.get('coda_to_id', {})
        if not self.id_to_syllable:
            # Vocab file only had the forward map; invert it.
            self.id_to_syllable = {int(v): k for k, v in self.syllable_to_id.items()}
        self.unk_syllable = self.syllable_to_id.get('<unk>', 1)
        self.unk_onset = self.onset_to_id.get('', 1)
        self.unk_nucleus = self.nucleus_to_id.get('', 1)
        self.unk_coda = self.coda_to_id.get('', 1)

    def _get_id(self, vocab, item, unk_id):
        """Look up *item* in *vocab*, falling back to *unk_id*."""
        return vocab.get(item, unk_id)

    def _split_onset_nucleus_coda(self, syllable: str):
        """Split a syllable into lower-cased (onset, nucleus, coda).

        The nucleus is the first maximal run of vowels; a vowel-less syllable
        is returned entirely as onset.
        """
        syl = syllable.lower()
        if not syl:
            return ('', '', '')
        nucleus_start = -1
        nucleus_end = -1
        for i, char in enumerate(syl):
            if char in self.VOWELS:
                if nucleus_start == -1:
                    nucleus_start = i
                nucleus_end = i + 1
            elif nucleus_start != -1:
                # First consonant after the vowel run ends the nucleus.
                break
        if nucleus_start == -1:
            return (syl, '', '')
        return (syl[:nucleus_start], syl[nucleus_start:nucleus_end], syl[nucleus_end:])

    def _detect_token_type(self, text: str) -> int:
        """Classify *text* as a number, punctuation run, or plain syllable."""
        text = text.strip()
        if not text:
            return self.TYPE_SYLLABLE
        if re.match(r'^-?\d+\.?\d*$', text):
            return self.TYPE_NUMBER
        if all(c in '.,!?;:\'"()-[]{}/<>@#$%^&*+=|\\`~' for c in text):
            return self.TYPE_PUNCT
        return self.TYPE_SYLLABLE

    def _syllabify_word(self, word: str) -> List[str]:
        """Return pyphen syllables of *word* (non-alphabetic chars dropped)."""
        clean = ''.join(c for c in word if c.isalpha())
        if not clean:
            return [word]
        hyphenated = self.hyphenator.inserted(clean)
        return hyphenated.split('-')

    def encode(self, text: str) -> List[Dict]:
        """Tokenize *text* into a list of per-syllable feature dicts.

        Each dict carries the syllable/onset/nucleus/coda ids plus position
        (0=middle, 1=first, 2=last, 3=only), capitalization, token type,
        trailing-space, and word-end flags.
        """
        # Split into word-ish runs (keeping apostrophes), single punctuation
        # marks, and whitespace runs.
        parts = re.findall(r"[\w']+|[.,!?;:\"'\-\(\)\[\]{}/<>@#$%^&*+=|\\`~]|\s+", text)
        tokens = []
        for part in parts:
            if part.isspace():
                # Whitespace is not a token; record it on the previous one.
                if tokens:
                    tokens[-1]['has_space_after'] = 1
                continue
            token_type = self._detect_token_type(part)
            if token_type == self.TYPE_NUMBER:
                # One token per character, e.g. "42" -> <num_4>, <num_2>.
                for i, char in enumerate(part):
                    syl_key = f'<num_{char}>'
                    tokens.append({
                        'syllable_id': self._get_id(self.syllable_to_id, syl_key, self.unk_syllable),
                        'onset_id': self._get_id(self.onset_to_id, '<num>', self.unk_onset),
                        'nucleus_id': self._get_id(self.nucleus_to_id, char, self.unk_nucleus),
                        'coda_id': self._get_id(self.coda_to_id, '', self.unk_coda),
                        'position': 3 if len(part) == 1 else (1 if i == 0 else (2 if i == len(part) - 1 else 0)),
                        'is_capitalized': 0,
                        'token_type': self.TYPE_NUMBER,
                        'has_space_after': 0,
                        'is_word_end': 1 if i == len(part) - 1 else 0,
                    })
                continue
            if token_type == self.TYPE_PUNCT:
                syl_key = f'<punct_{part}>'
                tokens.append({
                    'syllable_id': self._get_id(self.syllable_to_id, syl_key, self.unk_syllable),
                    'onset_id': self._get_id(self.onset_to_id, '<punct>', self.unk_onset),
                    'nucleus_id': self._get_id(self.nucleus_to_id, part, self.unk_nucleus),
                    'coda_id': self._get_id(self.coda_to_id, '', self.unk_coda),
                    'position': 3,
                    'is_capitalized': 0,
                    'token_type': self.TYPE_PUNCT,
                    'has_space_after': 0,
                    'is_word_end': 1,
                })
                continue
            # Regular word: syllabify and emit one token per syllable.
            syllables = self._syllabify_word(part)
            is_cap = part[0].isupper() if part else False
            for i, syl in enumerate(syllables):
                onset, nucleus, coda = self._split_onset_nucleus_coda(syl)
                pos = 0
                if len(syllables) == 1:
                    pos = 3
                elif i == 0:
                    pos = 1
                elif i == len(syllables) - 1:
                    pos = 2
                syl_lower = syl.lower()
                tokens.append({
                    'syllable_id': self._get_id(self.syllable_to_id, syl_lower, self.unk_syllable),
                    'onset_id': self._get_id(self.onset_to_id, onset, self.unk_onset),
                    'nucleus_id': self._get_id(self.nucleus_to_id, nucleus, self.unk_nucleus),
                    'coda_id': self._get_id(self.coda_to_id, coda, self.unk_coda),
                    'position': pos,
                    'is_capitalized': 1 if is_cap and i == 0 else 0,
                    'token_type': self.TYPE_SYLLABLE,
                    'has_space_after': 0,
                    'is_word_end': 1 if i == len(syllables) - 1 else 0,
                })
        return tokens

    def decode_syllable_id(self, sid: int) -> str:
        """Map a syllable id back to text; special markers decode to their
        payload, pad/unk decode to the empty string."""
        syl = self.id_to_syllable.get(sid, '<unk>')
        if syl.startswith('<punct_') and syl.endswith('>'):
            return syl[7:-1]
        if syl.startswith('<num_') and syl.endswith('>'):
            return syl[5:-1]
        if syl.startswith('<char_') and syl.endswith('>'):
            return syl[6:-1]
        if syl in ('<pad>', '<unk>'):
            return ''
        return syl
# Helpers
def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
    """Pack token feature dicts into a (1, seq_len, 9) long tensor.

    Missing keys default to 0; feature order matches the model's input layout.
    """
    keys = (
        'syllable_id', 'onset_id', 'nucleus_id', 'coda_id',
        'position', 'is_capitalized', 'token_type', 'has_space_after', 'is_word_end',
    )
    rows = []
    for tok in tokens:
        rows.append([tok.get(key, 0) for key in keys])
    batch = torch.tensor(rows, dtype=torch.long, device=device)
    # Add a leading batch dimension of size 1.
    return batch.unsqueeze(0)
def decode_tokens(tokenizer: "LunaTokenizer", tokens: List[Dict]) -> str:
    """Reassemble token feature dicts into readable text.

    Syllables are accumulated into a word until a word boundary is signalled
    (trailing space, word-end flag, final/only position, or punctuation);
    capitalization applies only to a word's first syllable, and punctuation
    is pulled tight against the preceding word.
    """
    parts = []
    current_word = []
    for token in tokens:
        syl_id = token.get('syllable_id', 0)
        space = token.get('has_space_after', 0)
        cap = token.get('is_capitalized', 0)
        position = token.get('position', 0)
        is_word_end = token.get('is_word_end', 0)
        token_type = token.get('token_type', 0)
        text = tokenizer.decode_syllable_id(syl_id)
        if not text:
            # pad/unk decode to '' — skip entirely.
            continue
        # Drop a punctuation token that exactly repeats the previous part.
        if token_type == 2 and parts and parts[-1].strip() == text:
            continue
        # Capitalize only the first syllable of a word.
        if cap and text and (not current_word) and text[0].isalpha():
            text = text[0].upper() + text[1:] if len(text) > 1 else text.upper()
        current_word.append(text)
        word_ends = (space == 1 or is_word_end == 1 or position in [2, 3] or token_type == 2)
        if word_ends:
            word = ''.join(current_word)
            # A lone punctuation mark swallows the space before it.
            if word in '.,!?;:\'"' and parts and parts[-1] == ' ':
                parts.pop()
            parts.append(word)
            # No space after opening brackets/quotes.
            if space == 1 and word not in '(\'"[{':
                parts.append(' ')
            current_word = []
    if current_word:
        # Flush a trailing, unterminated word.
        parts.append(''.join(current_word))
    result = ''.join(parts)
    # BUG FIX: collapse runs of spaces by replacing TWO spaces with one.
    # The original replaced ' ' with ' ' inside `while ' ' in result`, which
    # never terminates as soon as the output contains any space at all.
    while '  ' in result:
        result = result.replace('  ', ' ')
    for punct in '.,!?;:\'"':
        result = result.replace(f' {punct}', punct)
    return result
# Model Loading
# Model Loading
def load_model(checkpoint_path: str, data_dir: str):
    """Load the tokenizer vocab and model weights.

    Returns (model, tokenizer, val_loss); val_loss is 0 when the checkpoint
    does not record one.
    """
    tokenizer = LunaTokenizer()
    tokenizer.load_vocab(os.path.join(data_dir, "vocab.json"))
    # A directory may be passed instead of a file; prefer the best checkpoint.
    if os.path.isdir(checkpoint_path):
        for candidate_name in ("model_best.pt", "model_final.pt", "checkpoint_latest.pt"):
            candidate = os.path.join(checkpoint_path, candidate_name)
            if os.path.exists(candidate):
                checkpoint_path = candidate
                break
    print(f"Loading: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    model = Luna(checkpoint.get('config', LunaConfig()))
    raw_state = checkpoint.get('model', checkpoint.get('model_state_dict'))
    # torch.compile prefixes parameter names with '_orig_mod.' — strip it.
    cleaned_state = OrderedDict(
        (key[10:] if key.startswith('_orig_mod.') else key, value)
        for key, value in raw_state.items()
    )
    model.load_state_dict(cleaned_state)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer, checkpoint.get('val_loss', 0)
# Generation
# Generation
@torch.no_grad()
def generate(
    model: Luna,
    tokenizer: LunaTokenizer,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
) -> str:
    """Autoregressively generate a continuation of *prompt*.

    Samples the next syllable id with temperature / top-k / top-p filtering
    and an optional repetition penalty; the remaining token features
    (position, capitalization, type, spacing) come from the model's auxiliary
    heads.  Stops early on degenerate repetition.  Returns the decoded
    prompt concatenated with the decoded continuation.
    """
    model.eval()
    tokens = tokenizer.encode(prompt)
    if not tokens:
        # Empty/unencodable prompt: seed with one all-zero token so the
        # model still receives an input sequence.
        tokens = [{'syllable_id': 0, 'onset_id': 0, 'nucleus_id': 0, 'coda_id': 0,
                   'position': 0, 'is_capitalized': 0, 'token_type': 0,
                   'has_space_after': 0, 'is_word_end': 0}]
    prompt_len = len(tokens)
    recent_tokens = []  # sampled syllable ids (feeds the repetition penalty)
    recent_texts = []   # decoded syllable strings (feeds loop detection)
    max_len = model.config.max_seq_len
    pad_id = tokenizer.syllable_to_id.get('<pad>', 0)
    unk_id = tokenizer.syllable_to_id.get('<unk>', 1)
    # Ids of single-letter alphabetic "syllables"; masked out of sampling.
    bad_single_chars = {sid for syl, sid in tokenizer.syllable_to_id.items()
                        if len(syl) == 1 and syl.isalpha()}
    for _ in range(max_new_tokens):
        # Keep only the most recent max_len tokens as context.
        # NOTE(review): after this truncation, tokens[prompt_len:] at the end
        # can drop early generated tokens for long prompts — confirm intended.
        if len(tokens) > max_len:
            tokens = tokens[-max_len:]
        input_tensor = tokens_to_tensor(tokens, DEVICE)
        with torch.autocast(device_type='cuda' if DEVICE.type == 'cuda' else 'cpu', dtype=torch.bfloat16):
            logits, _ = model(input_tensor)
        # Temperature-scaled next-syllable logits; mask special/bad ids.
        syl_logits = logits['syllable'][0, -1, :].float() / temperature
        syl_logits[pad_id] = float('-inf')
        syl_logits[unk_id] = float('-inf')
        for bad_id in bad_single_chars:
            syl_logits[bad_id] = float('-inf')
        # Top-k filtering: keep only the k highest logits.
        if top_k > 0:
            top_k_val = min(top_k, syl_logits.size(-1))
            values, _ = torch.topk(syl_logits, top_k_val)
            syl_logits[syl_logits < values[-1]] = float('-inf')
        # Nucleus (top-p) filtering: keep the smallest prefix of sorted
        # tokens whose cumulative probability stays within top_p; the shift
        # (mask[1:] = mask[:-1]) always retains the single best token.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(syl_logits, descending=True)
            cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumsum > top_p
            mask[1:] = mask[:-1].clone()
            mask[0] = False
            sorted_logits[mask] = float('-inf')
            # Scatter the filtered logits back into original id order.
            syl_logits = torch.zeros_like(syl_logits).scatter_(-1, sorted_idx, sorted_logits)
        # Penalize ids sampled within the last 30 steps.
        if repetition_penalty > 1.0:
            for tid in set(recent_tokens[-30:]):
                if 0 <= tid < syl_logits.size(-1):
                    if syl_logits[tid] > 0:
                        syl_logits[tid] /= repetition_penalty
                    else:
                        syl_logits[tid] *= repetition_penalty
        probs = F.softmax(syl_logits, dim=-1)
        # Everything masked out -> nothing left to sample (probs would be NaN).
        if torch.isinf(syl_logits).all():
            break
        next_syl_id = torch.multinomial(probs, 1).item()
        syl_text = tokenizer.decode_syllable_id(next_syl_id)
        # Auxiliary token features come from the model's other heads.
        next_pos = logits['position'][0, -1, :].argmax().item()
        next_cap = logits['is_capitalized'][0, -1, :].argmax().item()
        next_type = logits['token_type'][0, -1, :].argmax().item()
        space_probs = torch.softmax(logits['has_space_after'][0, -1, :], dim=-1)
        # Low 0.25 threshold (instead of argmax) biases toward emitting
        # spaces between words.
        next_space = 1 if space_probs[1] > 0.25 else 0
        # Derive onset/nucleus/coda ids from the sampled syllable's text.
        onset, nucleus, coda = tokenizer._split_onset_nucleus_coda(syl_text)
        next_onset = tokenizer._get_id(tokenizer.onset_to_id, onset, tokenizer.unk_onset)
        next_nucleus = tokenizer._get_id(tokenizer.nucleus_to_id, nucleus, tokenizer.unk_nucleus)
        next_coda = tokenizer._get_id(tokenizer.coda_to_id, coda, tokenizer.unk_coda)
        next_word_end = 1 if (next_pos in [2, 3] or next_space == 1) else 0
        recent_tokens.append(next_syl_id)
        recent_texts.append(syl_text)
        tokens.append({
            'syllable_id': next_syl_id,
            'onset_id': next_onset,
            'nucleus_id': next_nucleus,
            'coda_id': next_coda,
            'position': next_pos,
            'is_capitalized': next_cap,
            'token_type': next_type,
            'has_space_after': next_space,
            'is_word_end': next_word_end,
        })
        # Degenerate-output guards: same syllable four times in a row ...
        if len(recent_texts) >= 4 and len(set(recent_texts[-4:])) == 1:
            break
        # ... or a two-syllable A-B-A-B-A-B cycle over the last six.
        if len(recent_texts) >= 6:
            last_6 = recent_texts[-6:]
            if last_6[0] == last_6[2] == last_6[4] and last_6[1] == last_6[3] == last_6[5]:
                break
    generated_tokens = tokens[prompt_len:]
    prompt_text = decode_tokens(tokenizer, tokens[:prompt_len])
    generated_text = decode_tokens(tokenizer, generated_tokens)
    return prompt_text + generated_text
# Interactive Mode
# Interactive Mode
def interactive_mode(model, tokenizer, args):
    """Run a REPL: read prompts from stdin and print a completion for each.

    Exits on 'quit'/'exit'/'q' or Ctrl-C; empty prompts are ignored.
    """
    print("\n" + "=" * 60)
    print("Interactive Mode (type 'quit' to exit)")
    print("=" * 60)
    while True:
        try:
            prompt = input("\nPrompt: ").strip()
            if prompt.lower() in ('quit', 'exit', 'q'):
                break
            if not prompt:
                continue
            completion = generate(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            print(f"\n{completion}")
        except KeyboardInterrupt:
            break
    print("\nGoodbye!")
# Main
# Main
def _build_arg_parser() -> argparse.ArgumentParser:
    """CLI options for Luna text generation."""
    parser = argparse.ArgumentParser(description="Generate text with Luna")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--max_tokens", type=int, default=100)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--seed", type=int, default=None)
    return parser


def main():
    """Entry point: load the model, then run batch or interactive generation."""
    args = _build_arg_parser().parse_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)
    print("=" * 60)
    print("Luna - Text Generation")
    print("=" * 60)
    print(f"Device: {DEVICE}")
    model, tokenizer, val_loss = load_model(args.checkpoint, args.data_dir)
    if val_loss:
        print(f"Val loss: {val_loss:.4f}")
    # No prompt given -> drop into the REPL instead of one-shot generation.
    if args.prompt is None:
        interactive_mode(model, tokenizer, args)
        return
    print(f"\nPrompt: '{args.prompt}'")
    print(f"Settings: temp={args.temperature}, top_k={args.top_k}, top_p={args.top_p}")
    print("-" * 60)
    for i in range(args.num_samples):
        if args.num_samples > 1:
            print(f"\n--- Sample {i+1} ---")
        sample = generate(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
        print(f"\n{sample}")
    print("\n" + "=" * 60)


if __name__ == "__main__":
    main()