# Copyright 2026 Jakub SykaƂa
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text-generation CLI for the syllable-level Luna language model.

Loads a trained checkpoint plus its vocabulary, tokenizes a prompt into
per-syllable feature vectors, and autoregressively samples new syllables
with temperature / top-k / top-p / repetition-penalty controls.
"""

import argparse
import json
import os
import re
from collections import OrderedDict
from typing import Dict, List

import torch
import torch.nn.functional as F

import pyphen

from model import Luna, LunaConfig

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
class LunaTokenizer:
    """Syllable-level tokenizer for Luna.

    Words are hyphenated into syllables with ``pyphen``; each syllable is
    further decomposed into onset / nucleus / coda (consonants before the
    first vowel run, the vowel run itself, and everything after). Numbers
    are emitted one digit per token and punctuation one mark per token via
    special vocabulary keys.
    """

    VOWELS = set('aeiouyAEIOUY')

    # Token-type codes; must match the values used when the training data
    # was built.
    TYPE_SYLLABLE = 0
    TYPE_NUMBER = 1
    TYPE_PUNCT = 2
    TYPE_SPECIAL = 3

    def __init__(self):
        self.hyphenator = pyphen.Pyphen(lang='en_US')
        self.syllable_to_id: Dict[str, int] = {}
        self.id_to_syllable: Dict[int, str] = {}
        self.onset_to_id: Dict[str, int] = {}
        self.nucleus_to_id: Dict[str, int] = {}
        self.coda_to_id: Dict[str, int] = {}
        # Fallback ids used until a vocabulary is loaded.
        self.unk_syllable = 1
        self.unk_onset = 1
        self.unk_nucleus = 1
        self.unk_coda = 1

    def load_vocab(self, vocab_path: str):
        """Load the JSON vocabulary produced by the data-prep pipeline."""
        with open(vocab_path) as f:
            vocab = json.load(f)
        self.syllable_to_id = vocab.get('syllable_to_id', {})
        self.id_to_syllable = {
            int(k): v for k, v in vocab.get('id_to_syllable', {}).items()
        }
        self.onset_to_id = vocab.get('onset_to_id', {})
        self.nucleus_to_id = vocab.get('nucleus_to_id', {})
        self.coda_to_id = vocab.get('coda_to_id', {})
        # Older vocab files may only ship the forward mapping.
        if not self.id_to_syllable:
            self.id_to_syllable = {int(v): k for k, v in self.syllable_to_id.items()}
        # NOTE(review): the special-token spellings below were reconstructed
        # from slice offsets elsewhere in this file — confirm against vocab.json.
        self.unk_syllable = self.syllable_to_id.get('<UNK>', 1)
        self.unk_onset = self.onset_to_id.get('<UNK>', 1)
        self.unk_nucleus = self.nucleus_to_id.get('<UNK>', 1)
        self.unk_coda = self.coda_to_id.get('<UNK>', 1)

    def _get_id(self, vocab, item, unk_id):
        """Map *item* through *vocab*, falling back to *unk_id*."""
        return vocab.get(item, unk_id)

    def _split_onset_nucleus_coda(self, syllable: str):
        """Split a syllable into (onset, nucleus, coda), all lowercase.

        The nucleus is the first maximal run of vowels; a vowel-less
        syllable is returned entirely as onset.
        """
        syl = syllable.lower()
        if not syl:
            return ('', '', '')
        nucleus_start = -1
        nucleus_end = -1
        for i, char in enumerate(syl):
            if char in self.VOWELS:
                if nucleus_start == -1:
                    nucleus_start = i
                nucleus_end = i + 1
            elif nucleus_start != -1:
                # First consonant after the vowel run ends the nucleus.
                break
        if nucleus_start == -1:
            return (syl, '', '')
        return (
            syl[:nucleus_start],
            syl[nucleus_start:nucleus_end],
            syl[nucleus_end:],
        )

    def _detect_token_type(self, text: str) -> int:
        """Classify *text* as syllable, number, or punctuation."""
        text = text.strip()
        if not text:
            return self.TYPE_SYLLABLE
        if re.match(r'^-?\d+\.?\d*$', text):
            return self.TYPE_NUMBER
        if all(c in '.,!?;:\'"()-[]{}/<>@#$%^&*+=|\\`~' for c in text):
            return self.TYPE_PUNCT
        return self.TYPE_SYLLABLE

    def _syllabify_word(self, word: str) -> List[str]:
        """Return *word* split into syllables via pyphen hyphenation."""
        clean = ''.join(c for c in word if c.isalpha())
        if not clean:
            return [word]
        hyphenated = self.hyphenator.inserted(clean)
        return hyphenated.split('-')

    def encode(self, text: str) -> List[Dict]:
        """Tokenize *text* into a list of per-token feature dicts.

        Each dict carries the nine features consumed by the model (see
        ``tokens_to_tensor``). ``position`` codes: 0=middle, 1=first,
        2=last, 3=only syllable of its word.
        """
        parts = re.findall(
            r"[\w']+|[.,!?;:\"'\-\(\)\[\]{}/<>@#$%^&*+=|\\`~]|\s+", text
        )
        tokens: List[Dict] = []
        for part in parts:
            if part.isspace():
                # Whitespace is not a token; it sets a flag on the previous one.
                if tokens:
                    tokens[-1]['has_space_after'] = 1
                continue
            token_type = self._detect_token_type(part)
            if token_type == self.TYPE_NUMBER:
                # Numbers are emitted digit-by-digit; the digit doubles as
                # the nucleus so the model sees its identity.
                for i, char in enumerate(part):
                    # NOTE(review): key spelling reconstructed — see decode_syllable_id.
                    syl_key = f'<NUM_{char}>'
                    tokens.append({
                        'syllable_id': self._get_id(self.syllable_to_id, syl_key, self.unk_syllable),
                        'onset_id': self._get_id(self.onset_to_id, '', self.unk_onset),
                        'nucleus_id': self._get_id(self.nucleus_to_id, char, self.unk_nucleus),
                        'coda_id': self._get_id(self.coda_to_id, '', self.unk_coda),
                        'position': 3 if len(part) == 1 else (1 if i == 0 else (2 if i == len(part) - 1 else 0)),
                        'is_capitalized': 0,
                        'token_type': self.TYPE_NUMBER,
                        'has_space_after': 0,
                        'is_word_end': 1 if i == len(part) - 1 else 0,
                    })
                continue
            if token_type == self.TYPE_PUNCT:
                # NOTE(review): key spelling reconstructed — see decode_syllable_id.
                syl_key = f'<PUNCT_{part}>'
                tokens.append({
                    'syllable_id': self._get_id(self.syllable_to_id, syl_key, self.unk_syllable),
                    'onset_id': self._get_id(self.onset_to_id, '', self.unk_onset),
                    'nucleus_id': self._get_id(self.nucleus_to_id, part, self.unk_nucleus),
                    'coda_id': self._get_id(self.coda_to_id, '', self.unk_coda),
                    'position': 3,
                    'is_capitalized': 0,
                    'token_type': self.TYPE_PUNCT,
                    'has_space_after': 0,
                    'is_word_end': 1,
                })
                continue
            # Regular word: one token per syllable.
            syllables = self._syllabify_word(part)
            is_cap = part[0].isupper() if part else False
            for i, syl in enumerate(syllables):
                onset, nucleus, coda = self._split_onset_nucleus_coda(syl)
                if len(syllables) == 1:
                    pos = 3
                elif i == 0:
                    pos = 1
                elif i == len(syllables) - 1:
                    pos = 2
                else:
                    pos = 0
                syl_lower = syl.lower()
                tokens.append({
                    'syllable_id': self._get_id(self.syllable_to_id, syl_lower, self.unk_syllable),
                    'onset_id': self._get_id(self.onset_to_id, onset, self.unk_onset),
                    'nucleus_id': self._get_id(self.nucleus_to_id, nucleus, self.unk_nucleus),
                    'coda_id': self._get_id(self.coda_to_id, coda, self.unk_coda),
                    'position': pos,
                    # Only the first syllable carries the capitalization flag.
                    'is_capitalized': 1 if is_cap and i == 0 else 0,
                    'token_type': self.TYPE_SYLLABLE,
                    'has_space_after': 0,
                    'is_word_end': 1 if i == len(syllables) - 1 else 0,
                })
        return tokens

    def decode_syllable_id(self, sid: int) -> str:
        """Map a syllable id back to surface text, unwrapping special tokens."""
        syl = self.id_to_syllable.get(sid, '')
        # NOTE(review): the prefixes below were reconstructed from the slice
        # offsets in the original (7 / 5 / 6 chars) — confirm against vocab.json.
        if syl.startswith('<PUNCT_'):
            return syl[7:-1]
        if syl.startswith('<NUM_'):
            return syl[5:-1]
        if syl.startswith('<CHAR_'):
            return syl[6:-1]
        if syl in ('<PAD>', '<UNK>'):
            return ''
        return syl


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
    """Pack token dicts into a (1, seq_len, 9) long tensor for the model."""
    feature_names = [
        'syllable_id', 'onset_id', 'nucleus_id', 'coda_id',
        'position', 'is_capitalized', 'token_type',
        'has_space_after', 'is_word_end',
    ]
    features = [[token.get(name, 0) for name in feature_names] for token in tokens]
    return torch.tensor(features, dtype=torch.long, device=device).unsqueeze(0)


def decode_tokens(tokenizer: LunaTokenizer, tokens: List[Dict]) -> str:
    """Render token dicts back into readable text.

    Syllables accumulate into words until a word boundary (explicit space,
    word-end flag, final/only position, or punctuation), then spacing and
    punctuation attachment are cleaned up.
    """
    parts: List[str] = []
    current_word: List[str] = []
    for token in tokens:
        syl_id = token.get('syllable_id', 0)
        space = token.get('has_space_after', 0)
        cap = token.get('is_capitalized', 0)
        position = token.get('position', 0)
        is_word_end = token.get('is_word_end', 0)
        token_type = token.get('token_type', 0)
        text = tokenizer.decode_syllable_id(syl_id)
        if not text:
            continue
        # Drop an immediately repeated punctuation mark.
        if token_type == LunaTokenizer.TYPE_PUNCT and parts and parts[-1].strip() == text:
            continue
        # Capitalize only at the start of a word.
        if cap and text and (not current_word) and text[0].isalpha():
            text = text[0].upper() + text[1:] if len(text) > 1 else text.upper()
        current_word.append(text)
        word_ends = (
            space == 1
            or is_word_end == 1
            or position in (2, 3)
            or token_type == LunaTokenizer.TYPE_PUNCT
        )
        if word_ends:
            word = ''.join(current_word)
            # Attach single-char punctuation to the previous word (len guard
            # avoids the substring semantics of `in` on multi-char words).
            if len(word) == 1 and word in '.,!?;:\'"' and parts and parts[-1] == ' ':
                parts.pop()
            parts.append(word)
            if space == 1 and word not in '(\'"[{':
                parts.append(' ')
            current_word = []
    if current_word:
        parts.append(''.join(current_word))
    result = ''.join(parts)
    # Collapse runs of spaces, then pull punctuation flush against words.
    while '  ' in result:
        result = result.replace('  ', ' ')
    for punct in '.,!?;:\'"':
        result = result.replace(f' {punct}', punct)
    return result


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model(checkpoint_path: str, data_dir: str):
    """Load model + tokenizer from *checkpoint_path* and *data_dir*.

    *checkpoint_path* may be a directory, in which case the best / final /
    latest checkpoint found inside is used. Returns
    ``(model, tokenizer, val_loss)``.
    """
    vocab_path = os.path.join(data_dir, "vocab.json")
    tokenizer = LunaTokenizer()
    tokenizer.load_vocab(vocab_path)
    if os.path.isdir(checkpoint_path):
        for name in ["model_best.pt", "model_final.pt", "checkpoint_latest.pt"]:
            ckpt = os.path.join(checkpoint_path, name)
            if os.path.exists(ckpt):
                checkpoint_path = ckpt
                break
    print(f"Loading: {checkpoint_path}")
    # weights_only=False is required to unpickle the stored LunaConfig;
    # only load checkpoints you trust (arbitrary-code-execution risk).
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = checkpoint.get('config', LunaConfig())
    model = Luna(config)
    state_dict = checkpoint.get('model', checkpoint.get('model_state_dict'))
    if state_dict is None:
        raise KeyError(
            f"Checkpoint {checkpoint_path} has neither 'model' nor 'model_state_dict'"
        )
    # Strip the torch.compile wrapper prefix if the model was compiled.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[10:] if k.startswith('_orig_mod.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer, checkpoint.get('val_loss', 0)


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
@torch.no_grad()
def generate(
    model: Luna,
    tokenizer: LunaTokenizer,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
) -> str:
    """Sample up to *max_new_tokens* syllables continuing *prompt*.

    Returns the decoded prompt + continuation. Stops early when sampling
    collapses (all logits filtered out) or the output starts looping
    (4x same syllable, or a 3x-repeated 2-gram).
    """
    model.eval()
    tokens = tokenizer.encode(prompt)
    if not tokens:
        # Seed with a single neutral token so the model has some input.
        tokens = [{
            'syllable_id': 0, 'onset_id': 0, 'nucleus_id': 0, 'coda_id': 0,
            'position': 0, 'is_capitalized': 0, 'token_type': 0,
            'has_space_after': 0, 'is_word_end': 0,
        }]
    prompt_len = len(tokens)
    recent_tokens: List[int] = []
    recent_texts: List[str] = []
    max_len = model.config.max_seq_len
    # NOTE(review): special-token spellings reconstructed — confirm vs vocab.json.
    pad_id = tokenizer.syllable_to_id.get('<PAD>', 0)
    unk_id = tokenizer.syllable_to_id.get('<UNK>', 1)
    # Ban single-letter syllables (note: this also bans the words "a"/"i").
    bad_single_chars = {
        sid for syl, sid in tokenizer.syllable_to_id.items()
        if len(syl) == 1 and syl.isalpha()
    }

    for _ in range(max_new_tokens):
        # BUGFIX: truncate only the model input; keep `tokens` intact so the
        # prompt/generated split by `prompt_len` below stays valid.
        context = tokens[-max_len:] if len(tokens) > max_len else tokens
        input_tensor = tokens_to_tensor(context, DEVICE)
        with torch.autocast(
            device_type='cuda' if DEVICE.type == 'cuda' else 'cpu',
            dtype=torch.bfloat16,
        ):
            logits, _ = model(input_tensor)

        syl_logits = logits['syllable'][0, -1, :].float() / temperature
        syl_logits[pad_id] = float('-inf')
        syl_logits[unk_id] = float('-inf')
        for bad_id in bad_single_chars:
            syl_logits[bad_id] = float('-inf')

        # Repetition penalty before top-k/top-p so it can influence which
        # tokens survive filtering (standard processor ordering).
        if repetition_penalty > 1.0:
            for tid in set(recent_tokens[-30:]):
                if 0 <= tid < syl_logits.size(-1):
                    if syl_logits[tid] > 0:
                        syl_logits[tid] /= repetition_penalty
                    else:
                        syl_logits[tid] *= repetition_penalty

        if top_k > 0:
            top_k_val = min(top_k, syl_logits.size(-1))
            values, _ = torch.topk(syl_logits, top_k_val)
            syl_logits[syl_logits < values[-1]] = float('-inf')

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(syl_logits, descending=True)
            cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumsum > top_p
            mask[1:] = mask[:-1].clone()  # keep the first token above threshold
            mask[0] = False
            sorted_logits[mask] = float('-inf')
            syl_logits = torch.zeros_like(syl_logits).scatter_(-1, sorted_idx, sorted_logits)

        # BUGFIX: check before softmax — softmax of an all -inf row is NaN.
        if torch.isinf(syl_logits).all():
            break
        probs = F.softmax(syl_logits, dim=-1)
        next_syl_id = torch.multinomial(probs, 1).item()
        syl_text = tokenizer.decode_syllable_id(next_syl_id)

        # Auxiliary heads are taken greedily from the last position.
        next_pos = logits['position'][0, -1, :].argmax().item()
        next_cap = logits['is_capitalized'][0, -1, :].argmax().item()
        next_type = logits['token_type'][0, -1, :].argmax().item()
        # Low 0.25 threshold biases toward inserting spaces (plain argmax
        # under-produced them).
        space_probs = torch.softmax(logits['has_space_after'][0, -1, :], dim=-1)
        next_space = 1 if space_probs[1] > 0.25 else 0

        # Derive phoneme-part ids for the sampled syllable text.
        onset, nucleus, coda = tokenizer._split_onset_nucleus_coda(syl_text)
        next_onset = tokenizer._get_id(tokenizer.onset_to_id, onset, tokenizer.unk_onset)
        next_nucleus = tokenizer._get_id(tokenizer.nucleus_to_id, nucleus, tokenizer.unk_nucleus)
        next_coda = tokenizer._get_id(tokenizer.coda_to_id, coda, tokenizer.unk_coda)
        next_word_end = 1 if (next_pos in (2, 3) or next_space == 1) else 0

        recent_tokens.append(next_syl_id)
        recent_texts.append(syl_text)
        tokens.append({
            'syllable_id': next_syl_id,
            'onset_id': next_onset,
            'nucleus_id': next_nucleus,
            'coda_id': next_coda,
            'position': next_pos,
            'is_capitalized': next_cap,
            'token_type': next_type,
            'has_space_after': next_space,
            'is_word_end': next_word_end,
        })

        # Loop detection: same syllable 4x, or an A-B-A-B-A-B 2-gram cycle.
        if len(recent_texts) >= 4 and len(set(recent_texts[-4:])) == 1:
            break
        if len(recent_texts) >= 6:
            last_6 = recent_texts[-6:]
            if last_6[0] == last_6[2] == last_6[4] and last_6[1] == last_6[3] == last_6[5]:
                break

    generated_tokens = tokens[prompt_len:]
    prompt_text = decode_tokens(tokenizer, tokens[:prompt_len])
    generated_text = decode_tokens(tokenizer, generated_tokens)
    return prompt_text + generated_text


# ---------------------------------------------------------------------------
# Interactive mode
# ---------------------------------------------------------------------------
def interactive_mode(model, tokenizer, args):
    """REPL loop: read prompts from stdin and print generations."""
    print("\n" + "=" * 60)
    print("Interactive Mode (type 'quit' to exit)")
    print("=" * 60)
    while True:
        try:
            prompt = input("\nPrompt: ").strip()
            if prompt.lower() in ('quit', 'exit', 'q'):
                break
            if not prompt:
                continue
            output = generate(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            print(f"\n{output}")
        except KeyboardInterrupt:
            break
    print("\nGoodbye!")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI args, load the model, and run one-shot or interactive generation."""
    parser = argparse.ArgumentParser(description="Generate text with Luna")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--max_tokens", type=int, default=100)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    print("=" * 60)
    print("Luna - Text Generation")
    print("=" * 60)
    print(f"Device: {DEVICE}")

    model, tokenizer, val_loss = load_model(args.checkpoint, args.data_dir)
    if val_loss:
        print(f"Val loss: {val_loss:.4f}")

    if args.prompt is None:
        interactive_mode(model, tokenizer, args)
        return

    print(f"\nPrompt: '{args.prompt}'")
    print(f"Settings: temp={args.temperature}, top_k={args.top_k}, top_p={args.top_p}")
    print("-" * 60)
    for i in range(args.num_samples):
        if args.num_samples > 1:
            print(f"\n--- Sample {i+1} ---")
        output = generate(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
        print(f"\n{output}")
    print("\n" + "=" * 60)


if __name__ == "__main__":
    main()