| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import json |
| | import argparse |
| | import re |
| | from typing import Dict, List |
| | from collections import OrderedDict |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import pyphen |
| |
|
| | from model import LunaConfig, Luna |
| |
|
| | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
| |
|
| | |
| |
|
| |
|
class LunaTokenizer:
    """Syllable-level tokenizer for Luna.

    Words are split into syllables with pyphen, and each syllable is further
    decomposed into onset / nucleus / coda. All pieces are mapped to integer
    ids through vocabularies loaded from a ``vocab.json`` file. Numbers are
    emitted one digit per token and punctuation one mark per token, each with
    its own ``token_type`` code.
    """

    # 'y' is deliberately treated as a vowel so syllables like "sky" get a nucleus.
    VOWELS = set('aeiouyAEIOUY')

    # Codes stored in each token's 'token_type' feature.
    TYPE_SYLLABLE = 0
    TYPE_NUMBER = 1
    TYPE_PUNCT = 2
    TYPE_SPECIAL = 3

    # Patterns/constants hoisted to class level so encode() does not repeat
    # regex-cache lookups per call; semantics are identical to the inline forms.
    _NUMBER_RE = re.compile(r'^-?\d+\.?\d*$')
    _PART_RE = re.compile(r"[\w']+|[.,!?;:\"'\-\(\)\[\]{}/<>@#$%^&*+=|\\`~]|\s+")
    _PUNCT_CHARS = '.,!?;:\'"()-[]{}/<>@#$%^&*+=|\\`~'

    def __init__(self):
        self.hyphenator = pyphen.Pyphen(lang='en_US')
        self.syllable_to_id: Dict[str, int] = {}
        self.id_to_syllable: Dict[int, str] = {}
        self.onset_to_id: Dict[str, int] = {}
        self.nucleus_to_id: Dict[str, int] = {}
        self.coda_to_id: Dict[str, int] = {}

        # Fallback <unk> ids, overwritten by load_vocab().
        self.unk_syllable = 1
        self.unk_onset = 1
        self.unk_nucleus = 1
        self.unk_coda = 1

    def load_vocab(self, vocab_path: str) -> None:
        """Load the id vocabularies from a JSON file produced at training time.

        If the file lacks an explicit ``id_to_syllable`` map, it is rebuilt by
        inverting ``syllable_to_id``.
        """
        # Explicit encoding: vocab files are written as UTF-8; without this the
        # read would depend on the platform's locale encoding.
        with open(vocab_path, encoding='utf-8') as f:
            vocab = json.load(f)

        self.syllable_to_id = vocab.get('syllable_to_id', {})
        # JSON keys are strings; ids must be ints for lookup by sampled id.
        self.id_to_syllable = {int(k): v for k, v in vocab.get('id_to_syllable', {}).items()}
        self.onset_to_id = vocab.get('onset_to_id', {})
        self.nucleus_to_id = vocab.get('nucleus_to_id', {})
        self.coda_to_id = vocab.get('coda_to_id', {})

        if not self.id_to_syllable:
            self.id_to_syllable = {int(v): k for k, v in self.syllable_to_id.items()}

        self.unk_syllable = self.syllable_to_id.get('<unk>', 1)
        self.unk_onset = self.onset_to_id.get('', 1)
        self.unk_nucleus = self.nucleus_to_id.get('', 1)
        self.unk_coda = self.coda_to_id.get('', 1)

    def _get_id(self, vocab: Dict[str, int], item: str, unk_id: int) -> int:
        """Look up *item* in *vocab*, falling back to *unk_id*."""
        return vocab.get(item, unk_id)

    def _split_onset_nucleus_coda(self, syllable: str):
        """Split a syllable (lowercased) into (onset, nucleus, coda).

        The nucleus is the first contiguous run of vowels; everything before
        it is the onset and everything after is the coda. A vowel-less
        syllable is returned entirely as onset.
        """
        syl = syllable.lower()
        if not syl:
            return ('', '', '')

        nucleus_start = -1
        nucleus_end = -1

        for i, char in enumerate(syl):
            if char in self.VOWELS:
                if nucleus_start == -1:
                    nucleus_start = i
                nucleus_end = i + 1
            elif nucleus_start != -1:
                # First consonant after the vowel run ends the nucleus.
                break

        if nucleus_start == -1:
            return (syl, '', '')

        return (syl[:nucleus_start], syl[nucleus_start:nucleus_end], syl[nucleus_end:])

    def _detect_token_type(self, text: str) -> int:
        """Classify *text* as TYPE_NUMBER, TYPE_PUNCT, or TYPE_SYLLABLE."""
        text = text.strip()
        if not text:
            return self.TYPE_SYLLABLE
        if self._NUMBER_RE.match(text):
            return self.TYPE_NUMBER
        if all(c in self._PUNCT_CHARS for c in text):
            return self.TYPE_PUNCT
        return self.TYPE_SYLLABLE

    def _syllabify_word(self, word: str) -> List[str]:
        """Return *word* split into syllables via pyphen (letters only)."""
        clean = ''.join(c for c in word if c.isalpha())
        if not clean:
            # Nothing alphabetic to hyphenate; keep the raw token as-is.
            return [word]
        hyphenated = self.hyphenator.inserted(clean)
        return hyphenated.split('-')

    def encode(self, text: str) -> List[Dict]:
        """Encode *text* into a list of per-syllable feature dicts.

        Each dict carries the syllable/onset/nucleus/coda ids plus positional
        flags: position (0=mid, 1=first, 2=last, 3=single), capitalization,
        token type, trailing-space flag, and word-end flag.
        """
        parts = self._PART_RE.findall(text)
        tokens = []

        for part in parts:
            if part.isspace():
                # Whitespace is not a token; it marks the previous token.
                if tokens:
                    tokens[-1]['has_space_after'] = 1
                continue

            token_type = self._detect_token_type(part)

            if token_type == self.TYPE_NUMBER:
                # Numbers are encoded digit-by-digit as <num_X> syllables.
                for i, char in enumerate(part):
                    syl_key = f'<num_{char}>'
                    tokens.append({
                        'syllable_id': self._get_id(self.syllable_to_id, syl_key, self.unk_syllable),
                        'onset_id': self._get_id(self.onset_to_id, '<num>', self.unk_onset),
                        'nucleus_id': self._get_id(self.nucleus_to_id, char, self.unk_nucleus),
                        'coda_id': self._get_id(self.coda_to_id, '', self.unk_coda),
                        'position': 3 if len(part) == 1 else (1 if i == 0 else (2 if i == len(part) - 1 else 0)),
                        'is_capitalized': 0,
                        'token_type': self.TYPE_NUMBER,
                        'has_space_after': 0,
                        'is_word_end': 1 if i == len(part) - 1 else 0,
                    })
                continue

            if token_type == self.TYPE_PUNCT:
                syl_key = f'<punct_{part}>'
                tokens.append({
                    'syllable_id': self._get_id(self.syllable_to_id, syl_key, self.unk_syllable),
                    'onset_id': self._get_id(self.onset_to_id, '<punct>', self.unk_onset),
                    'nucleus_id': self._get_id(self.nucleus_to_id, part, self.unk_nucleus),
                    'coda_id': self._get_id(self.coda_to_id, '', self.unk_coda),
                    'position': 3,
                    'is_capitalized': 0,
                    'token_type': self.TYPE_PUNCT,
                    'has_space_after': 0,
                    'is_word_end': 1,
                })
                continue

            syllables = self._syllabify_word(part)
            is_cap = part[0].isupper() if part else False

            for i, syl in enumerate(syllables):
                onset, nucleus, coda = self._split_onset_nucleus_coda(syl)

                pos = 0
                if len(syllables) == 1:
                    pos = 3
                elif i == 0:
                    pos = 1
                elif i == len(syllables) - 1:
                    pos = 2

                syl_lower = syl.lower()

                tokens.append({
                    'syllable_id': self._get_id(self.syllable_to_id, syl_lower, self.unk_syllable),
                    'onset_id': self._get_id(self.onset_to_id, onset, self.unk_onset),
                    'nucleus_id': self._get_id(self.nucleus_to_id, nucleus, self.unk_nucleus),
                    'coda_id': self._get_id(self.coda_to_id, coda, self.unk_coda),
                    'position': pos,
                    # Only the first syllable of a word carries capitalization.
                    'is_capitalized': 1 if is_cap and i == 0 else 0,
                    'token_type': self.TYPE_SYLLABLE,
                    'has_space_after': 0,
                    'is_word_end': 1 if i == len(syllables) - 1 else 0,
                })

        return tokens

    def decode_syllable_id(self, sid: int) -> str:
        """Map a syllable id back to surface text, unwrapping special markers."""
        syl = self.id_to_syllable.get(sid, '<unk>')
        if syl.startswith('<punct_') and syl.endswith('>'):
            return syl[7:-1]
        if syl.startswith('<num_') and syl.endswith('>'):
            return syl[5:-1]
        if syl.startswith('<char_') and syl.endswith('>'):
            return syl[6:-1]
        if syl in ('<pad>', '<unk>'):
            # Pad/unknown tokens render as nothing.
            return ''
        return syl
| |
|
| |
|
| |
|
| | |
| |
|
def tokens_to_tensor(tokens: List[Dict], device) -> torch.Tensor:
    """Pack a list of token feature dicts into a (1, T, 9) long tensor.

    Missing features default to 0. The feature order matches what the model
    expects on its input axis.
    """
    field_order = (
        'syllable_id', 'onset_id', 'nucleus_id', 'coda_id',
        'position', 'is_capitalized', 'token_type', 'has_space_after', 'is_word_end',
    )
    matrix = [[tok.get(field, 0) for field in field_order] for tok in tokens]
    batch = torch.tensor(matrix, dtype=torch.long, device=device)
    return batch.unsqueeze(0)
| |
|
| |
|
def decode_tokens(tokenizer: "LunaTokenizer", tokens: List[Dict]) -> str:
    """Reassemble surface text from token dicts.

    Syllables accumulate into a word until a word-boundary signal appears
    (trailing space, word-end flag, final/single position, or punctuation);
    punctuation is re-attached to the preceding word without a space.
    """
    PUNCT = '.,!?;:\'"'
    assembled: List[str] = []   # finished words / separators
    pending: List[str] = []     # syllables of the word being built

    for tok in tokens:
        piece = tokenizer.decode_syllable_id(tok.get('syllable_id', 0))
        if not piece:
            continue

        ttype = tok.get('token_type', 0)
        # Drop punctuation that merely repeats the previous emitted mark.
        if ttype == 2 and assembled and assembled[-1].strip() == piece:
            continue

        if tok.get('is_capitalized', 0) and not pending and piece[0].isalpha():
            piece = piece[0].upper() + piece[1:] if len(piece) > 1 else piece.upper()

        pending.append(piece)

        space = tok.get('has_space_after', 0)
        boundary = (
            space == 1
            or tok.get('is_word_end', 0) == 1
            or tok.get('position', 0) in [2, 3]
            or ttype == 2
        )
        if boundary:
            word = ''.join(pending)
            # Pull punctuation back against the previous word.
            if word in PUNCT and assembled and assembled[-1] == ' ':
                assembled.pop()
            assembled.append(word)
            if space == 1 and word not in '(\'"[{':
                assembled.append(' ')
            pending = []

    if pending:
        assembled.append(''.join(pending))

    text = ''.join(assembled)
    # Collapse runs of spaces, then strip any space left before punctuation.
    while '  ' in text:
        text = text.replace('  ', ' ')
    for mark in PUNCT:
        text = text.replace(f' {mark}', mark)

    return text
| |
|
| |
|
| |
|
| | |
| |
|
| |
|
def load_model(checkpoint_path: str, data_dir: str):
    """Load the tokenizer vocab and a Luna checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file, or a directory containing
            one of ``model_best.pt`` / ``model_final.pt`` / ``checkpoint_latest.pt``
            (searched in that priority order).
        data_dir: Directory containing ``vocab.json``.

    Returns:
        (model, tokenizer, val_loss) with the model on DEVICE in eval mode.

    Raises:
        KeyError: If the checkpoint contains no model state dict.
    """
    vocab_path = os.path.join(data_dir, "vocab.json")
    tokenizer = LunaTokenizer()
    tokenizer.load_vocab(vocab_path)

    if os.path.isdir(checkpoint_path):
        for name in ["model_best.pt", "model_final.pt", "checkpoint_latest.pt"]:
            ckpt = os.path.join(checkpoint_path, name)
            if os.path.exists(ckpt):
                checkpoint_path = ckpt
                break

    print(f"Loading: {checkpoint_path}")
    # weights_only=False: the checkpoint pickles a config object, not just tensors.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    config = checkpoint.get('config', LunaConfig())
    model = Luna(config)

    state_dict = checkpoint.get('model', checkpoint.get('model_state_dict'))
    if state_dict is None:
        # Fail with a clear message instead of an AttributeError on .items().
        raise KeyError(
            f"Checkpoint has no 'model' or 'model_state_dict' entry: {checkpoint_path}"
        )

    # Strip the torch.compile wrapper prefix if the model was compiled when saved.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[10:] if k.startswith('_orig_mod.') else k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.to(DEVICE)
    model.eval()

    return model, tokenizer, checkpoint.get('val_loss', 0)
| |
|
| |
|
| | |
| |
|
| |
|
@torch.no_grad()
def generate(
    model: Luna,
    tokenizer: LunaTokenizer,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
) -> str:
    """Autoregressively sample syllable tokens and return prompt + completion.

    Sampling pipeline per step: temperature-scale the syllable logits, mask
    pad/unk/single-letter syllables, apply top-k then nucleus (top-p)
    filtering, apply a repetition penalty over the last 30 sampled ids, then
    draw from the resulting distribution. Auxiliary heads (position,
    capitalization, token type) are taken greedily; the space head uses a
    probability threshold. Generation stops early on degenerate repetition
    (4 identical syllables, or an A-B-A-B-A-B cycle) or if all logits are
    masked out.
    """
    model.eval()

    tokens = tokenizer.encode(prompt)
    if not tokens:
        # Empty prompt: seed with a single all-zero token so the model has input.
        tokens = [{'syllable_id': 0, 'onset_id': 0, 'nucleus_id': 0, 'coda_id': 0,
                   'position': 0, 'is_capitalized': 0, 'token_type': 0,
                   'has_space_after': 0, 'is_word_end': 0}]

    prompt_len = len(tokens)
    recent_tokens = []   # sampled syllable ids, for repetition penalty
    recent_texts = []    # decoded syllable strings, for loop detection
    max_len = model.config.max_seq_len

    pad_id = tokenizer.syllable_to_id.get('<pad>', 0)
    unk_id = tokenizer.syllable_to_id.get('<unk>', 1)

    # Ids of single-letter alphabetic syllables; these are banned outright
    # (presumably they produce noisy one-letter output — see masking below).
    bad_single_chars = {sid for syl, sid in tokenizer.syllable_to_id.items()
                        if len(syl) == 1 and syl.isalpha()}

    for _ in range(max_new_tokens):
        # Keep only the most recent max_len tokens as model context.
        if len(tokens) > max_len:
            tokens = tokens[-max_len:]

        input_tensor = tokens_to_tensor(tokens, DEVICE)

        with torch.autocast(device_type='cuda' if DEVICE.type == 'cuda' else 'cpu', dtype=torch.bfloat16):
            logits, _ = model(input_tensor)

        # Last-position syllable logits, temperature-scaled in float32.
        syl_logits = logits['syllable'][0, -1, :].float() / temperature

        # Hard-mask tokens we never want to sample.
        syl_logits[pad_id] = float('-inf')
        syl_logits[unk_id] = float('-inf')
        for bad_id in bad_single_chars:
            syl_logits[bad_id] = float('-inf')

        # Top-k: keep only the k highest logits.
        if top_k > 0:
            top_k_val = min(top_k, syl_logits.size(-1))
            values, _ = torch.topk(syl_logits, top_k_val)
            syl_logits[syl_logits < values[-1]] = float('-inf')

        # Nucleus (top-p): drop the tail beyond cumulative probability top_p.
        # The mask is shifted right by one so the first token over the
        # threshold is still kept.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(syl_logits, descending=True)
            cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumsum > top_p
            mask[1:] = mask[:-1].clone()
            mask[0] = False
            sorted_logits[mask] = float('-inf')
            # Scatter the filtered logits back to their original positions.
            syl_logits = torch.zeros_like(syl_logits).scatter_(-1, sorted_idx, sorted_logits)

        # Repetition penalty over the last 30 sampled ids (CTRL-style:
        # divide positive logits, multiply negative ones).
        if repetition_penalty > 1.0:
            for tid in set(recent_tokens[-30:]):
                if 0 <= tid < syl_logits.size(-1):
                    if syl_logits[tid] > 0:
                        syl_logits[tid] /= repetition_penalty
                    else:
                        syl_logits[tid] *= repetition_penalty

        probs = F.softmax(syl_logits, dim=-1)
        # If everything got masked, probs is degenerate — stop generating.
        if torch.isinf(syl_logits).all():
            break

        next_syl_id = torch.multinomial(probs, 1).item()
        syl_text = tokenizer.decode_syllable_id(next_syl_id)

        # Auxiliary heads: greedy argmax.
        next_pos = logits['position'][0, -1, :].argmax().item()
        next_cap = logits['is_capitalized'][0, -1, :].argmax().item()
        next_type = logits['token_type'][0, -1, :].argmax().item()

        # Space head: thresholded at 0.25 rather than argmax, biasing toward
        # emitting spaces (presumably tuned empirically — TODO confirm).
        space_probs = torch.softmax(logits['has_space_after'][0, -1, :], dim=-1)
        next_space = 1 if space_probs[1] > 0.25 else 0

        # Derive onset/nucleus/coda ids from the sampled syllable's text so
        # the fed-back features stay consistent with the syllable id.
        onset, nucleus, coda = tokenizer._split_onset_nucleus_coda(syl_text)
        next_onset = tokenizer._get_id(tokenizer.onset_to_id, onset, tokenizer.unk_onset)
        next_nucleus = tokenizer._get_id(tokenizer.nucleus_to_id, nucleus, tokenizer.unk_nucleus)
        next_coda = tokenizer._get_id(tokenizer.coda_to_id, coda, tokenizer.unk_coda)

        # Word ends on final/single position or when a space follows.
        next_word_end = 1 if (next_pos in [2, 3] or next_space == 1) else 0

        recent_tokens.append(next_syl_id)
        recent_texts.append(syl_text)

        tokens.append({
            'syllable_id': next_syl_id,
            'onset_id': next_onset,
            'nucleus_id': next_nucleus,
            'coda_id': next_coda,
            'position': next_pos,
            'is_capitalized': next_cap,
            'token_type': next_type,
            'has_space_after': next_space,
            'is_word_end': next_word_end,
        })

        # Degenerate-repetition guards: AAAA or ABABAB.
        if len(recent_texts) >= 4 and len(set(recent_texts[-4:])) == 1:
            break
        if len(recent_texts) >= 6:
            last_6 = recent_texts[-6:]
            if last_6[0] == last_6[2] == last_6[4] and last_6[1] == last_6[3] == last_6[5]:
                break

    # Decode prompt and continuation separately, then concatenate.
    generated_tokens = tokens[prompt_len:]
    prompt_text = decode_tokens(tokenizer, tokens[:prompt_len])
    generated_text = decode_tokens(tokenizer, generated_tokens)

    return prompt_text + generated_text
| |
|
| |
|
| |
|
| | |
| |
|
| |
|
def interactive_mode(model, tokenizer, args):
    """Run a simple REPL: read a prompt, print a generated completion.

    Exits on 'quit'/'exit'/'q' or Ctrl-C.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Interactive Mode (type 'quit' to exit)")
    print(banner)

    while True:
        try:
            user_text = input("\nPrompt: ").strip()
            if user_text.lower() in ('quit', 'exit', 'q'):
                break
            if not user_text:
                continue

            completion = generate(
                model=model,
                tokenizer=tokenizer,
                prompt=user_text,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            print(f"\n{completion}")
        except KeyboardInterrupt:
            break

    print("\nGoodbye!")
| |
|
| |
|
| | |
| |
|
| |
|
def main():
    """CLI entry point: parse arguments, load the model, and generate.

    With no --prompt, drops into interactive mode; otherwise produces
    --num_samples completions of --prompt.
    """
    parser = argparse.ArgumentParser(description="Generate text with Luna")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--prompt", type=str, default=None)
    # Sampling knobs, registered data-driven to keep defaults in one place.
    for flag, arg_type, default in (
        ("--max_tokens", int, 100),
        ("--temperature", float, 0.8),
        ("--top_k", int, 40),
        ("--top_p", float, 0.9),
        ("--repetition_penalty", float, 1.0),
        ("--num_samples", int, 1),
    ):
        parser.add_argument(flag, type=arg_type, default=default)
    parser.add_argument("--seed", type=int, default=None)

    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    rule = "=" * 60
    print(rule)
    print("Luna - Text Generation")
    print(rule)
    print(f"Device: {DEVICE}")

    model, tokenizer, val_loss = load_model(args.checkpoint, args.data_dir)
    if val_loss:
        print(f"Val loss: {val_loss:.4f}")

    if args.prompt is None:
        interactive_mode(model, tokenizer, args)
        return

    print(f"\nPrompt: '{args.prompt}'")
    print(f"Settings: temp={args.temperature}, top_k={args.top_k}, top_p={args.top_p}")
    print("-" * 60)

    for sample_idx in range(args.num_samples):
        if args.num_samples > 1:
            print(f"\n--- Sample {sample_idx + 1} ---")
        sample_text = generate(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
        print(f"\n{sample_text}")

    print("\n" + rule)
| |
|
| |
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()