import csv
import json
import re
from collections import Counter
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Embedding, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Model

# Special tokens must be distinct, non-empty strings. If they are identical they
# collapse into a single ``token_to_id`` entry, so padding, unknown, start and
# end would all share one id and generation could never tell "end" from "pad".
# NOTE(review): the exact original spellings are assumed; confirm them against
# the first four vocab entries of any previously saved tokenizer.json.
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
START_TOKEN = "<start>"
END_TOKEN = "<end>"
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]

# Coarse splitter: word runs, a few multi-character operators, then any single
# non-word character (including whitespace).
TOKEN_PATTERN = re.compile(
    r"\w+|==|!=|<=|>=|->|=>|::|\+\+|--|&&|\|\||//|/\*|\*/|[^\w]",
    re.UNICODE,
)
TOKENIZER_VERSION = "regex_subword_v1"
DEFAULT_TOKEN_CHUNK_SIZE = 2
DEFAULT_WORD_CHUNK_SIZE = 4
DEFAULT_CYRILLIC_CHUNK_SIZE = 2
# Splits a word piece into Cyrillic runs, acronym runs, (Capitalized) lowercase
# runs, digit runs and underscore runs. The trailing ``\w`` alternative is a
# catch-all: without it ``findall`` silently skipped word characters that no
# other alternative matches (e.g. accented Latin letters, or a Latin capital
# directly followed by a Cyrillic letter), losing text during tokenization.
IDENTIFIER_BOUNDARY_PATTERN = re.compile(
    r"[А-Яа-яЁё]+|[A-Z]+(?=[A-Z][a-z]|[0-9_]|$)|[A-Z]?[a-z]+|[0-9]+|_+|[^\w]|\w",
    re.UNICODE,
)
# Single characters that are always placed in a fitted vocabulary so that
# vocabulary-based segmentation can fall back to them.
COMMON_FALLBACK_TOKENS = list(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
    "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
    "_ \n\t.,!?;:()[]{}<>+-*/=%'\"`\\|&^~#$@"
)
SUPPORTED_DATA_EXTENSIONS = {".txt", ".md", ".json", ".jsonl", ".ndjson", ".csv"}
TEXT_FIELDS = ("text", "content", "completion", "response", "output", "answer")
QUESTION_FIELDS = ("question", "prompt", "instruction", "input", "query")
PROMPT_FIELDS = QUESTION_FIELDS
ANSWER_FIELDS = ("answer", "completion", "response", "output")
QUESTION_LABEL = "Question:"
ANSWER_LABEL = "Answer:"
TG_DATASET_PRO_ID = "AILaborant/tg_dataset_pro"
DEFAULT_COLAB_NOTEBOOK = "tg_dataset_pro_colab_train.ipynb"


class ThreeChunkTokenizer:
    """Regex subword tokenizer tuned for Russian prose and code-like text.

    Text is first split with ``TOKEN_PATTERN``; word-like pieces are then split
    on identifier boundaries and cut into fixed-size chunks. Once the tokenizer
    has been fitted, every raw token is re-segmented greedily (longest match
    first) against the vocabulary, emitting ``UNK_TOKEN`` for characters that
    cannot be matched.
    """

    def __init__(
        self,
        vocab=None,
        max_vocab_size=128000,
        chunk_size=DEFAULT_TOKEN_CHUNK_SIZE,
        word_chunk_size=DEFAULT_WORD_CHUNK_SIZE,
        cyrillic_chunk_size=DEFAULT_CYRILLIC_CHUNK_SIZE,
        tokenizer_version=TOKENIZER_VERSION,
    ):
        """Create a tokenizer, optionally restoring a previously fitted vocab.

        Args:
            vocab: Existing vocabulary (iterable of tokens); when falsy the
                vocabulary starts as just ``SPECIAL_TOKENS``.
            max_vocab_size: Target upper bound used by ``fit``.
            chunk_size: Chunk length for the ``fixed_chunk_legacy`` mode.
            word_chunk_size: Chunk length for non-Cyrillic word segments.
            cyrillic_chunk_size: Chunk length for segments containing Cyrillic.
            tokenizer_version: ``"fixed_chunk_legacy"`` selects the legacy
                tokenization path; any other value uses the regex subword path.
        """
        self.max_vocab_size = max_vocab_size
        # Chunk sizes are clamped to at least 1 so slicing always advances.
        self.chunk_size = max(1, int(chunk_size))
        self.word_chunk_size = max(1, int(word_chunk_size))
        self.cyrillic_chunk_size = max(1, int(cyrillic_chunk_size))
        self.tokenizer_version = tokenizer_version
        self.vocab = list(vocab) if vocab else list(SPECIAL_TOKENS)
        self._rebuild_index()

    def _rebuild_index(self):
        """Recompute the token->id map and the longest token length."""
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.max_token_length = max((len(token) for token in self.vocab), default=1)

    def _is_fitted(self):
        """Return True when the vocab holds more than the special tokens."""
        return len(self.vocab) > len(SPECIAL_TOKENS)

    def _is_cyrillic(self, text):
        """Return True if any character of ``text`` is a Russian letter."""
        return any("а" <= char.lower() <= "я" or char in "Ёё" for char in text)

    def _chunk_segment(self, segment):
        """Cut one identifier segment into chunks.

        Underscore runs become single characters, digit runs are cut every four
        characters, and other segments use the Cyrillic or word chunk size.
        """
        if not segment:
            return []
        if segment.startswith("_"):
            return list(segment)
        if segment.isdigit():
            return [segment[i : i + 4] for i in range(0, len(segment), 4)]
        chunk_size = self.cyrillic_chunk_size if self._is_cyrillic(segment) else self.word_chunk_size
        return [segment[i : i + chunk_size] for i in range(0, len(segment), chunk_size)]

    def _tokenize_word_piece(self, piece):
        """Split a word-like piece on identifier boundaries, then chunk it."""
        tokens = []
        for segment in IDENTIFIER_BOUNDARY_PATTERN.findall(piece):
            tokens.extend(self._chunk_segment(segment))
        return tokens

    def _tokenize_raw(self, text):
        """Tokenize ``text`` without consulting the vocabulary."""
        tokens = []
        for piece in TOKEN_PATTERN.findall(text):
            if piece.isalnum() or "_" in piece:
                tokens.extend(self._tokenize_word_piece(piece))
            else:
                tokens.append(piece)
        return tokens

    def _tokenize_fixed_chunk_legacy(self, text):
        """Legacy tokenization: word pieces cut into ``chunk_size`` slices."""
        tokens = []
        for piece in TOKEN_PATTERN.findall(text):
            if piece.isalnum() or "_" in piece:
                tokens.extend(piece[i : i + self.chunk_size] for i in range(0, len(piece), self.chunk_size))
            else:
                tokens.append(piece)
        return tokens

    def _tokenize_with_vocab(self, token):
        """Segment ``token`` into vocabulary entries by greedy longest match.

        Characters for which no vocabulary entry matches are replaced, one at a
        time, with ``UNK_TOKEN``.
        """
        if token in self.token_to_id:
            return [token]

        pieces = []
        index = 0
        while index < len(token):
            match = None
            # No vocab entry is longer than max_token_length, so cap the search.
            max_end = min(len(token), index + self.max_token_length)
            for end in range(max_end, index, -1):
                candidate = token[index:end]
                if candidate in self.token_to_id:
                    match = candidate
                    break
            if match is None:
                pieces.append(UNK_TOKEN)
                index += 1
            else:
                pieces.append(match)
                index += len(match)
        return pieces

    def tokenize(self, text):
        """Return the list of string tokens for ``text``.

        The legacy version bypasses vocabulary segmentation entirely; an
        unfitted tokenizer returns the raw regex/chunk tokens.
        """
        if self.tokenizer_version == "fixed_chunk_legacy":
            return self._tokenize_fixed_chunk_legacy(text)

        tokens = self._tokenize_raw(text)
        if not self._is_fitted():
            return tokens

        segmented = []
        for token in tokens:
            segmented.extend(self._tokenize_with_vocab(token))
        return segmented

    def detokenize(self, tokens):
        """Concatenate tokens into text, dropping special tokens."""
        return "".join(token for token in tokens if token not in SPECIAL_TOKENS)

    def fit(self, texts):
        """Build the vocabulary from ``texts`` and return ``self``.

        The vocabulary is: special tokens, then required single characters
        (``COMMON_FALLBACK_TOKENS`` plus every character seen in ``texts``),
        then the most frequent raw tokens until ``max_vocab_size`` is reached.
        Required characters are always kept, so the final size may exceed
        ``max_vocab_size`` when there are very many distinct characters.
        """
        counts = Counter()
        character_counts = Counter()
        for text in texts:
            counts.update(self._tokenize_raw(text))
            character_counts.update(text)

        required_tokens = []
        seen = set(SPECIAL_TOKENS)
        for token in COMMON_FALLBACK_TOKENS + [token for token, _ in character_counts.most_common()]:
            if token not in seen:
                required_tokens.append(token)
                seen.add(token)

        remaining = max(0, self.max_vocab_size - len(SPECIAL_TOKENS) - len(required_tokens))
        learned_tokens = []
        for token, _ in counts.most_common():
            # Check the budget before appending so that a budget of zero adds
            # nothing (checking afterwards let one extra token slip in).
            if len(learned_tokens) >= remaining:
                break
            if token in seen:
                continue
            learned_tokens.append(token)
            seen.add(token)

        self.vocab = list(SPECIAL_TOKENS) + required_tokens + learned_tokens
        self._rebuild_index()
        return self

    def encode(self, text, add_boundaries=True, maxlen=None):
        """Convert ``text`` to a list of token ids.

        Args:
            text: Input string.
            add_boundaries: Wrap the tokens in ``START_TOKEN``/``END_TOKEN``.
            maxlen: When given, truncate or right-pad with the pad id so the
                result has exactly ``maxlen`` entries.

        Returns:
            List of integer ids; unknown tokens map to the ``UNK_TOKEN`` id.
        """
        tokens = self.tokenize(text)
        if add_boundaries:
            tokens = [START_TOKEN] + tokens + [END_TOKEN]
        unk_id = self.token_to_id[UNK_TOKEN]
        ids = [self.token_to_id.get(token, unk_id) for token in tokens]
        if maxlen is not None:
            ids = ids[:maxlen]
            ids.extend([self.token_to_id[PAD_TOKEN]] * (maxlen - len(ids)))
        return ids

    def decode_ids(self, ids):
        """Convert ids back to text, ignoring ids outside the vocabulary."""
        vocab_size = len(self.vocab)
        # Reject negative ids as well: they would otherwise index from the end.
        tokens = [self.vocab[int(idx)] for idx in ids if 0 <= int(idx) < vocab_size]
        return self.detokenize(tokens)

    def to_dict(self):
        """Return a JSON-serializable snapshot of the tokenizer state."""
        return {
            "vocab": self.vocab,
            "max_vocab_size": self.max_vocab_size,
            "chunk_size": self.chunk_size,
            "word_chunk_size": self.word_chunk_size,
            "cyrillic_chunk_size": self.cyrillic_chunk_size,
            "tokenizer_version": self.tokenizer_version,
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a tokenizer from the mapping produced by ``to_dict``."""
        # Older saved artifacts did not store this value and were trained with
        # 3-character chunks, so keep those models' prompt tokenization stable.
        return cls(
            vocab=data["vocab"],
            max_vocab_size=data.get("max_vocab_size", 128000),
            chunk_size=data.get("chunk_size", 3),
            word_chunk_size=data.get("word_chunk_size", data.get("chunk_size", 3)),
            cyrillic_chunk_size=data.get("cyrillic_chunk_size", data.get("chunk_size", 3)),
            tokenizer_version=data.get("tokenizer_version", "fixed_chunk_legacy"),
        )


@tf.keras.utils.register_keras_serializable()
class TransformerBlock(tf.keras.layers.Layer):
    """Causal self-attention block with a feed-forward sublayer.

    Each sublayer output passes through dropout, is added to its input, and is
    then layer-normalized.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        """Create the sublayers.

        Args:
            embed_dim: Width of the block input/output; also the per-head
                ``key_dim`` passed to ``MultiHeadAttention``.
            num_heads: Number of attention heads.
            ff_dim: Hidden width of the feed-forward sublayer.
            rate: Dropout rate applied after each sublayer.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply masked self-attention followed by the feed-forward sublayer."""
        seq_len = tf.shape(inputs)[1]
        # Lower-triangular mask: each position attends to itself and earlier ones.
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), dtype=tf.bool), -1, 0)
        attn_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attn_output = self.dropout1(attn_output, training=training)
        # Casts keep the residual additions in a single dtype.
        attn_output = tf.cast(attn_output, inputs.dtype)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        ffn_output = tf.cast(ffn_output, out1.dtype)
        return self.layernorm2(out1 + ffn_output)

    def build(self, input_shape):
        """Build the sublayers explicitly for the given input shape."""
        self.att.build(input_shape, input_shape)
        self.ffn.build(input_shape)
        self.layernorm1.build(input_shape)
        self.layernorm2.build(input_shape)
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config
@tf.keras.utils.register_keras_serializable()
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        """Create the two embedding tables.

        Args:
            maxlen: Number of rows in the position embedding table.
            vocab_size: Number of rows in the token embedding table.
            embed_dim: Embedding width shared by both tables.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Embed token ids ``x`` and add the embedding of each position."""
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        # Match dtypes before adding the two embeddings.
        positions = tf.cast(positions, x.dtype)
        return x + positions

    def build(self, input_shape):
        """Build both embedding tables explicitly."""
        self.token_emb.build(input_shape)
        self.pos_emb.build((self.maxlen,))
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super().get_config()
        config.update({
            "maxlen": self.maxlen,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        })
        return config


def configure_tensorflow(use_mixed_precision=True, use_xla=True):
    """Apply best-effort TensorFlow runtime settings.

    Each setting is attempted independently; a failure leaves that setting
    disabled instead of aborting.

    Returns:
        Dict with ``gpu_count`` and booleans ``memory_growth``,
        ``mixed_precision`` and ``xla`` describing what was actually enabled.
    """
    gpus = tf.config.list_physical_devices("GPU")
    memory_growth_enabled = False
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
            memory_growth_enabled = True
        except (RuntimeError, ValueError):
            # Colab may initialize the GPU before training starts. In that case
            # memory growth cannot be changed, but training can continue safely.
            pass

    mixed_precision_enabled = False
    if use_mixed_precision and gpus:
        try:
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
            mixed_precision_enabled = True
        except ValueError:
            pass

    xla_enabled = False
    if use_xla:
        try:
            tf.config.optimizer.set_jit(True)
            xla_enabled = True
        except ValueError:
            pass

    return {
        "gpu_count": len(gpus),
        "memory_growth": memory_growth_enabled,
        "mixed_precision": mixed_precision_enabled,
        "xla": xla_enabled,
    }


def build_model(
    vocab_size,
    maxlen=100,
    embed_dim=96,
    num_heads=4,
    ff_dim=192,
    dropout=0.2,
    num_layers=1,
    jit_compile=True,
):
    """Build and compile the next-token prediction model.

    The model embeds ``maxlen`` token ids, applies at least one
    ``TransformerBlock`` and ends with a float32 softmax over the vocabulary
    at every position.
    """
    inputs = tf.keras.Input(shape=(maxlen,), dtype=tf.int32)
    x = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(inputs)
    for _ in range(max(1, int(num_layers))):
        x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout)(x)
    outputs = Dense(vocab_size, activation="softmax", dtype="float32")(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        weighted_metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="answer_accuracy")],
        jit_compile=jit_compile,
    )
    return model


def stringify_value(value):
    """Convert a JSON-like value to stripped text (``""`` for ``None``).

    Lists are joined with newlines (skipping empty items) and dicts are
    rendered with ``record_to_text``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, (int, float, bool)):
        return str(value)
    if isinstance(value, list):
        return "\n".join(filter(None, (stringify_value(item) for item in value)))
    if isinstance(value, dict):
        return record_to_text(value)
    return str(value).strip()


def messages_to_text(messages):
    """Render a chat-style message list as ``role: content`` lines."""
    lines = []
    for message in messages:
        if not isinstance(message, dict):
            text = stringify_value(message)
            if text:
                lines.append(text)
            continue

        role = stringify_value(message.get("role") or message.get("from") or message.get("speaker"))
        content = stringify_value(message.get("content") or message.get("value") or message.get("text"))
        if role and content:
            lines.append(f"{role}: {content}")
        elif content:
            lines.append(content)
    return "\n".join(lines).strip()


def first_field(record, field_names):
    """Return the first non-empty stringified field of ``record``, else ``""``."""
    for field_name in field_names:
        value = stringify_value(record.get(field_name))
        if value:
            return value
    return ""


def format_question_answer(question, answer):
    """Format a training example as a labelled question/answer pair."""
    return f"{QUESTION_LABEL} {question}\n{ANSWER_LABEL} {answer}"


def format_question_prompt(question):
    """Format a generation prompt that ends right after the answer label."""
    return f"{QUESTION_LABEL} {question}\n{ANSWER_LABEL}"


def record_to_text(record):
    """Convert one dataset record to training text.

    Preference order: question+answer fields, a ``messages`` list, a direct
    text field, and finally every non-empty field as ``key: value`` lines.
    """
    if not isinstance(record, dict):
        return stringify_value(record)

    question = first_field(record, QUESTION_FIELDS)
    answer = first_field(record, ANSWER_FIELDS)
    if question and answer:
        return format_question_answer(question, answer)

    if isinstance(record.get("messages"), list):
        text = messages_to_text(record["messages"])
        if text:
            return text

    direct_text = first_field(record, TEXT_FIELDS)
    if direct_text:
        return direct_text

    parts = []
    for key, value in record.items():
        text = stringify_value(value)
        if text:
            parts.append(f"{key}: {text}")
    return "\n".join(parts).strip()


def walk_json_records(data):
    """Yield individual records from a parsed JSON document.

    Lists are flattened recursively. A dict that holds a list under one of the
    known container keys yields the contents of those lists; any other value
    is yielded as a single record.
    """
    if isinstance(data, list):
        for item in data:
            yield from walk_json_records(item)
    elif isinstance(data, dict):
        yielded_nested = False
        for key in ("train", "validation", "test", "data", "examples", "rows", "records"):
            if isinstance(data.get(key), list):
                yielded_nested = True
                yield from walk_json_records(data[key])
        if not yielded_nested:
            yield data
    else:
        yield data


def read_json_file(path):
    """Read a ``.json`` file, falling back to JSON Lines when appropriate.

    Raises:
        json.JSONDecodeError: If the file is neither valid JSON nor JSONL-like.
        ValueError: If the JSON Lines fallback hits an invalid line.
    """
    raw_text = path.read_text(encoding="utf-8-sig")
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        # Some datasets are JSONL but are still named .json. If parsing a full
        # JSON document fails, try reading one JSON object per line.
        if "Extra data" not in exc.msg:
            raise
        return read_jsonl_text(raw_text, path)
    return [record_to_text(record) for record in walk_json_records(data)]


def read_jsonl_text(raw_text, path):
    """Parse JSON Lines text; ``path`` is only used in error messages.

    Raises:
        ValueError: If a non-blank line is not valid JSON.
    """
    texts = []
    for line_number, line in enumerate(raw_text.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue
        try:
            texts.append(record_to_text(json.loads(line)))
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON on line {line_number} in {path}") from exc
    return texts


def read_jsonl_file(path):
    """Read a JSON Lines file into a list of training texts."""
    return read_jsonl_text(path.read_text(encoding="utf-8-sig"), path)


def read_csv_file(path):
    """Read a CSV file with a header row into a list of training texts."""
    # utf-8-sig (as used by the JSON readers) strips a leading BOM; with plain
    # utf-8 the BOM would be glued to the first header name and break lookups.
    with path.open("r", encoding="utf-8-sig", newline="") as file:
        rows = csv.DictReader(file)
        return [record_to_text(row) for row in rows]


def read_dataset_file(path):
    """Read one dataset file by extension; unsupported types yield ``[]``."""
    extension = path.suffix.lower()
    if extension in {".txt", ".md"}:
        # utf-8-sig keeps a leading BOM out of the training text.
        return [path.read_text(encoding="utf-8-sig")]
    if extension == ".json":
        return read_json_file(path)
    if extension in {".jsonl", ".ndjson"}:
        return read_jsonl_file(path)
    if extension == ".csv":
        return read_csv_file(path)
    return []


def read_training_texts(paths):
    """Collect non-blank training texts from files and directories.

    Directories are searched recursively for supported extensions.

    Raises:
        ValueError: If no non-blank text was found.
    """
    texts = []
    for raw_path in paths:
        path = Path(raw_path)
        if path.is_dir():
            for file_path in sorted(path.rglob("*")):
                if file_path.is_file() and file_path.suffix.lower() in SUPPORTED_DATA_EXTENSIONS:
                    texts.extend(read_dataset_file(file_path))
        else:
            texts.extend(read_dataset_file(path))

    texts = [text for text in texts if text and text.strip()]
    if not texts:
        supported = ", ".join(sorted(SUPPORTED_DATA_EXTENSIONS))
        raise ValueError(f"No training text found. Supported dataset files: {supported}")
    return texts


def estimate_dataset_defaults(paths):
    """Scan the dataset and return statistics plus suggested hyperparameters."""
    texts = read_training_texts(paths)
    # An unfitted tokenizer is stateless across calls, so one instance suffices.
    tokenizer = ThreeChunkTokenizer()
    counter = Counter()
    token_count = 0
    sequence_lengths = []
    for text in texts:
        tokens = tokenizer.tokenize(text)
        counter.update(tokens)
        # +2 accounts for the start and end boundary tokens added by encode().
        sequence_length = len(tokens) + 2
        sequence_lengths.append(sequence_length)
        token_count += sequence_length

    unique_tokens = len(counter)
    length_p95 = int(np.percentile(sequence_lengths, 95)) if sequence_lengths else 96
    # Round the 95th percentile up to a multiple of 32, clamped to [96, 256].
    recommended_maxlen = min(256, max(96, ((length_p95 + 31) // 32) * 32))
    recommended_vocab = min(32000, max(2048, unique_tokens + len(SPECIAL_TOKENS) + 512))
    large_dataset = token_count >= 1_000_000
    recommended_stride = 48 if large_dataset else 32
    recommended_max_sequences = 60000 if large_dataset else 0

    return {
        "records": len(texts),
        "tokens": token_count,
        "unique_tokens": unique_tokens,
        "epochs": 3,
        "batch_size": 128,
        "validation_split": 0.02,
        "maxlen": recommended_maxlen,
        "stride": recommended_stride,
        "max_vocab_size": recommended_vocab,
        "max_sequences": recommended_max_sequences,
        "embed_dim": 96,
        "num_heads": 4,
        "num_layers": 2,
        "ff_dim": 192,
        "dropout": 0.15,
        "answer_only_loss": True,
    }


def find_subsequence(values, pattern):
    """Return the first index where ``pattern`` occurs in ``values``, or -1."""
    if not pattern or len(pattern) > len(values):
        return -1
    for index in range(len(values) - len(pattern) + 1):
        if values[index : index + len(pattern)] == pattern:
            return index
    return -1


def make_loss_mask(ids, tokenizer, answer_only_loss=True):
    """Return a float32 per-token loss mask for ``ids``.

    With ``answer_only_loss`` the mask is 1 only after the first occurrence of
    the encoded answer label; when the label is absent, or the option is off,
    every position is 1.
    """
    if not answer_only_loss:
        return np.ones(len(ids), dtype="float32")

    answer_label_ids = tokenizer.encode(ANSWER_LABEL, add_boundaries=False)
    answer_label_start = find_subsequence(ids, answer_label_ids)
    if answer_label_start < 0:
        return np.ones(len(ids), dtype="float32")

    answer_start = answer_label_start + len(answer_label_ids)
    mask = np.zeros(len(ids), dtype="float32")
    mask[answer_start:] = 1.0
    return mask


def _stack_training_arrays(sequences, weights):
    """Stack windows into (inputs, targets, sample_weights) arrays."""
    data = np.asarray(sequences, dtype="int32")
    sample_weights = np.asarray(weights, dtype="float32")
    # Inputs drop the last token, targets drop the first (next-token shift).
    return data[:, :-1], data[:, 1:], sample_weights


def make_training_arrays(texts, tokenizer, maxlen, stride=48, max_sequences=0, answer_only_loss=True):
    """Turn texts into sliding-window next-token training arrays.

    Args:
        texts: Iterable of training strings.
        tokenizer: Fitted tokenizer providing ``encode`` and ``token_to_id``.
        maxlen: Model context length; each window holds ``maxlen + 1`` ids.
        stride: Step between window starts (at least 1).
        max_sequences: Stop after this many windows; 0 means no limit.
        answer_only_loss: Weight only answer tokens and skip windows whose
            targets carry no weight.

    Returns:
        Tuple ``(x, y, sample_weights)`` of arrays with ``maxlen`` columns.
    """
    stride = max(1, int(stride))
    max_sequences = max(0, int(max_sequences))
    pad_id = tokenizer.token_to_id[PAD_TOKEN]
    window = maxlen + 1
    sequences = []
    weights = []

    for text in texts:
        ids = tokenizer.encode(text, add_boundaries=True)
        loss_mask = make_loss_mask(ids, tokenizer, answer_only_loss=answer_only_loss)
        if len(ids) < window:
            pad_amount = window - len(ids)
            ids.extend([pad_id] * pad_amount)
            # Padding positions never contribute to the loss.
            loss_mask = np.pad(loss_mask, (0, pad_amount), constant_values=0.0)

        last_start = max(0, len(ids) - window)
        starts = list(range(0, last_start + 1, stride))
        # Always include the final window so the end of the text is covered.
        if starts[-1] != last_start:
            starts.append(last_start)

        for start in starts:
            target_weights = loss_mask[start + 1 : start + window]
            if answer_only_loss and not np.any(target_weights):
                continue
            sequences.append(ids[start : start + window])
            weights.append(target_weights)
            if max_sequences and len(sequences) >= max_sequences:
                return _stack_training_arrays(sequences, weights)

    if not sequences:
        # Fall back to one all-padding window so callers always get arrays.
        sequences.append([pad_id] * window)
        weights.append(np.ones(maxlen, dtype="float32"))

    return _stack_training_arrays(sequences, weights)


def make_tf_datasets(x, y, sample_weights, batch_size, validation_split=0.0, shuffle_buffer=10000):
    """Shuffle, split and batch arrays into ``tf.data`` datasets.

    The permutation uses a fixed seed, so the train/validation split is
    reproducible. The validation split is clamped so that at least one
    training example always remains.

    Returns:
        ``(train_ds, val_ds, train_count, val_count)``; ``val_ds`` is ``None``
        and ``val_count`` is 0 when there is no validation split.
    """
    indices = np.random.default_rng(42).permutation(len(x))
    x = x[indices]
    y = y[indices]
    sample_weights = sample_weights[indices]

    val_count = int(len(x) * validation_split)
    if validation_split > 0 and val_count == 0 and len(x) > 1:
        val_count = 1
    # A split >= 1 would leave no training data and a negative split would
    # produce a nonsensical slice; clamp to [0, len(x) - 1].
    val_count = max(0, min(val_count, len(x) - 1))

    if val_count:
        train_x, val_x = x[:-val_count], x[-val_count:]
        train_y, val_y = y[:-val_count], y[-val_count:]
        train_weights, val_weights = sample_weights[:-val_count], sample_weights[-val_count:]
    else:
        train_x, train_y = x, y
        train_weights = sample_weights
        val_x, val_y = None, None
        val_weights = None

    train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y, train_weights))
    # The shuffle buffer must be at least 1.
    buffer_size = max(1, min(len(train_x), shuffle_buffer))
    train_ds = train_ds.shuffle(buffer_size, reshuffle_each_iteration=True)
    train_ds = train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    val_ds = None
    if val_x is not None:
        val_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y, val_weights))
        val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds, len(train_x), 0 if val_x is None else len(val_x)


def model_paths(model_dir):
    """Return the (model, tokenizer, config) file paths inside ``model_dir``."""
    path = Path(model_dir)
    return path / "model.keras", path / "tokenizer.json", path / "config.json"


def save_artifacts(model, tokenizer, model_dir, config):
    """Write the model, tokenizer and config into ``model_dir``."""
    path = Path(model_dir)
    path.mkdir(parents=True, exist_ok=True)
    model_path, tokenizer_path, config_path = model_paths(path)
    model.save(model_path)
    tokenizer_path.write_text(json.dumps(tokenizer.to_dict(), ensure_ascii=False, indent=2), encoding="utf-8")
    config_path.write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8")


def load_artifacts(model_dir="tg-medium"):
    """Load ``(model, tokenizer, config)`` from ``model_dir``.

    The config file is optional and defaults to an empty dict.

    Raises:
        FileNotFoundError: If the model or tokenizer file is missing.
    """
    model_path, tokenizer_path, config_path = model_paths(model_dir)
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not tokenizer_path.exists():
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")

    model = tf.keras.models.load_model(
        model_path,
        custom_objects={
            "TokenAndPositionEmbedding": TokenAndPositionEmbedding,
            "TransformerBlock": TransformerBlock,
        },
        compile=False,
    )
    tokenizer = ThreeChunkTokenizer.from_dict(json.loads(tokenizer_path.read_text(encoding="utf-8")))
    config = json.loads(config_path.read_text(encoding="utf-8")) if config_path.exists() else {}
    return model, tokenizer, config


def clean_answer_text(text):
    """Trim generated text to the answer only.

    Everything from a following question/prompt marker onward is removed, and
    a leading answer label is stripped.
    """
    text = text.strip()
    for marker in (QUESTION_LABEL, "Prompt:"):
        marker_index = text.find(marker)
        if marker_index >= 0:
            text = text[:marker_index].strip()
    if text.startswith(ANSWER_LABEL):
        text = text[len(ANSWER_LABEL) :].strip()
    return text.strip()


def generate_text(model, tokenizer, prompt, maxlen=100, max_new_tokens=80, temperature=1.0, qa_mode=True):
    """Generate a continuation for ``prompt`` one token at a time.

    Args:
        model: Model returning per-position probabilities over the vocabulary.
        tokenizer: Tokenizer used for encoding and decoding.
        prompt: User text; wrapped in the question/answer template in QA mode.
        maxlen: Context length the model was built with.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: ``<= 0`` selects greedy decoding; otherwise the
            distribution is sharpened or flattened before sampling.
        qa_mode: Apply the QA prompt template and clean the output.

    Returns:
        The generated text (cleaned with ``clean_answer_text`` in QA mode).
    """
    if qa_mode:
        prompt = format_question_prompt(prompt)

    context = [tokenizer.token_to_id[START_TOKEN]]
    context.extend(tokenizer.encode(prompt, add_boundaries=False))
    end_id = tokenizer.token_to_id[END_TOKEN]
    pad_id = tokenizer.token_to_id[PAD_TOKEN]
    generated = []

    for _ in range(max_new_tokens):
        input_ids = context[-maxlen:]
        last_position = len(input_ids) - 1
        # Right-pad to the fixed model input length; the prediction is read at
        # the last real position.
        padded = input_ids + [pad_id] * (maxlen - len(input_ids))
        predictions = model.predict(np.asarray([padded], dtype="int32"), verbose=0)[0, last_position]

        if temperature <= 0:
            next_id = int(np.argmax(predictions))
        else:
            # Work in float64 and subtract the max before exponentiating: with
            # a low temperature the float32 values underflow to all zeros
            # (0/0 -> NaN), and float32 rounding can make np.random.choice
            # reject the probabilities as not summing to 1.
            probs = np.asarray(predictions, dtype="float64")
            logits = np.log(np.maximum(probs, 1e-9)) / temperature
            logits -= np.max(logits)
            probabilities = np.exp(logits)
            probabilities /= np.sum(probabilities)
            next_id = int(np.random.choice(len(probabilities), p=probabilities))

        if next_id == end_id:
            break
        context.append(next_id)
        generated.append(next_id)

    output = tokenizer.decode_ids(generated)
    return clean_answer_text(output) if qa_mode else output


def train_command(args):
    """Train a model from ``args`` and save the resulting artifacts."""
    tf_settings = configure_tensorflow(
        use_mixed_precision=args.mixed_precision,
        use_xla=args.xla,
    )
    texts = read_training_texts(args.data)
    tokenizer = ThreeChunkTokenizer(max_vocab_size=args.max_vocab_size).fit(texts)
    x, y, sample_weights = make_training_arrays(
        texts,
        tokenizer,
        args.maxlen,
        stride=args.stride,
        max_sequences=args.max_sequences,
        answer_only_loss=args.answer_only_loss,
    )
    train_ds, val_ds, train_count, val_count = make_tf_datasets(
        x,
        y,
        sample_weights,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
    )
    model = build_model(
        vocab_size=len(tokenizer.vocab),
        maxlen=args.maxlen,
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
        ff_dim=args.ff_dim,
        dropout=args.dropout,
        num_layers=args.num_layers,
        jit_compile=args.xla,
    )

    print(
        f"Training on {train_count:,} sequences"
        f"{f' with {val_count:,} validation sequences' if val_count else ''}."
    )
    print(
        f"Vocab: {len(tokenizer.vocab):,} | Batch: {args.batch_size} | "
        f"Stride: {args.stride} | Blocks: {args.num_layers} | GPUs: {tf_settings['gpu_count']} | "
        f"Answer-only loss: {args.answer_only_loss} | "
        f"Memory growth: {tf_settings['memory_growth']} | "
        f"Mixed precision: {tf_settings['mixed_precision']} | XLA: {tf_settings['xla']}"
    )

    config = {
        "maxlen": args.maxlen,
        "embed_dim": args.embed_dim,
        "num_heads": args.num_heads,
        "num_layers": args.num_layers,
        "ff_dim": args.ff_dim,
        "dropout": args.dropout,
        "vocab_size": len(tokenizer.vocab),
        "tokenizer_chunk_size": tokenizer.chunk_size,
        "tokenizer_word_chunk_size": tokenizer.word_chunk_size,
        "tokenizer_cyrillic_chunk_size": tokenizer.cyrillic_chunk_size,
        "tokenizer_version": tokenizer.tokenizer_version,
        "stride": args.stride,
        "max_sequences": args.max_sequences,
        "answer_only_loss": args.answer_only_loss,
        "mixed_precision": tf_settings["mixed_precision"],
        "xla": tf_settings["xla"],
    }

    model.fit(train_ds, epochs=args.epochs, validation_data=val_ds)
    save_artifacts(model, tokenizer, args.model_dir, config)
    print(f"Saved model, tokenizer, and config to {args.model_dir}")


def chat_command(args):
    """Run an interactive chat loop against a saved model."""
    model, tokenizer, config = load_artifacts(args.model_dir)
    maxlen = int(config.get("maxlen", args.maxlen))
    print("Chat ready. Type /exit or /quit to stop.")

    while True:
        try:
            prompt = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if prompt.lower() in {"/exit", "/quit"}:
            break
        if not prompt:
            # Nothing was typed; ask again instead of generating from nothing.
            continue
        response = generate_text(
            model,
            tokenizer,
            prompt,
            maxlen=maxlen,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
        print(f"bot> {response}")


def generate_command(args):
    """Generate and print a single response for ``args.prompt``."""
    model, tokenizer, config = load_artifacts(args.model_dir)
    maxlen = int(config.get("maxlen", args.maxlen))
    print(
        generate_text(
            model,
            tokenizer,
            args.prompt,
            maxlen=maxlen,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
    )


def save_command(args):
    """Load artifacts from ``args.model_dir`` and re-save to ``args.output_dir``."""
    model, tokenizer, config = load_artifacts(args.model_dir)
    save_artifacts(model, tokenizer, args.output_dir, config)
    print(f"Saved copy to {args.output_dir}")


def make_notebook_cell(cell_type, source):
    """Build a notebook cell dict with one newline-terminated entry per line."""
    return {
        "cell_type": cell_type,
        "metadata": {},
        "source": [line + "\n" for line in source.splitlines()],
    }


def build_colab_notebook(dataset_id=TG_DATASET_PRO_ID):
    """Build a Colab notebook (as a dict) that trains on a Hugging Face dataset.

    The notebook embeds this file's own source so that Colab can recreate it
    as ``app.py`` and import the training helpers from it.
    """
    app_source = Path(__file__).read_text(encoding="utf-8")
    # Doubled braces below are literal braces in the generated notebook code;
    # only {dataset_id!r} is substituted here.
    train_code = f"""
DATASET_ID = {dataset_id!r}
SPLIT = "train"
MODEL_DIR = "/content/tg-dataset-pro-model"

# 0 means use the full dataset. Set a smaller number for a quick smoke test.
MAX_ROWS = 0

# These defaults are intentionally Colab-friendly. Increase MAX_SEQUENCES,
# EMBED_DIM, NUM_LAYERS, or EPOCHS if you have a stronger GPU/runtime.
EPOCHS = 3
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.02
MAXLEN = 128
STRIDE = 64
MAX_VOCAB_SIZE = 32000
MAX_SEQUENCES = 120000
EMBED_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 2
FF_DIM = 256
DROPOUT = 0.15
ANSWER_ONLY_LOSS = True

from datasets import load_dataset
from app import (
    ThreeChunkTokenizer,
    build_model,
    configure_tensorflow,
    make_tf_datasets,
    make_training_arrays,
    record_to_text,
    save_artifacts,
)

tf_settings = configure_tensorflow(use_mixed_precision=True, use_xla=True)
print("TensorFlow settings:", tf_settings)

raw_dataset = load_dataset(DATASET_ID, split=SPLIT)
texts = []
for index, record in enumerate(raw_dataset):
    if MAX_ROWS and index >= MAX_ROWS:
        break
    text = record_to_text(record)
    if text and text.strip():
        texts.append(text)

if not texts:
    raise ValueError("No text rows were found in the Hugging Face dataset.")

print(f"Loaded {{len(texts):,}} training rows from {{DATASET_ID}}.")

tokenizer = ThreeChunkTokenizer(max_vocab_size=MAX_VOCAB_SIZE).fit(texts)
x, y, sample_weights = make_training_arrays(
    texts,
    tokenizer,
    MAXLEN,
    stride=STRIDE,
    max_sequences=MAX_SEQUENCES,
    answer_only_loss=ANSWER_ONLY_LOSS,
)
train_ds, val_ds, train_count, val_count = make_tf_datasets(
    x,
    y,
    sample_weights,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)
model = build_model(
    vocab_size=len(tokenizer.vocab),
    maxlen=MAXLEN,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    dropout=DROPOUT,
    num_layers=NUM_LAYERS,
    jit_compile=True,
)

print(
    f"Training on {{train_count:,}} sequences"
    f"{{f' with {{val_count:,}} validation sequences' if val_count else ''}}."
)
print(f"Vocab: {{len(tokenizer.vocab):,}} | Maxlen: {{MAXLEN}} | Batch: {{BATCH_SIZE}}")

model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

config = {{
    "dataset_id": DATASET_ID,
    "dataset_split": SPLIT,
    "max_rows": MAX_ROWS,
    "maxlen": MAXLEN,
    "embed_dim": EMBED_DIM,
    "num_heads": NUM_HEADS,
    "num_layers": NUM_LAYERS,
    "ff_dim": FF_DIM,
    "dropout": DROPOUT,
    "vocab_size": len(tokenizer.vocab),
    "tokenizer_chunk_size": tokenizer.chunk_size,
    "tokenizer_word_chunk_size": tokenizer.word_chunk_size,
    "tokenizer_cyrillic_chunk_size": tokenizer.cyrillic_chunk_size,
    "tokenizer_version": tokenizer.tokenizer_version,
    "stride": STRIDE,
    "max_sequences": MAX_SEQUENCES,
    "answer_only_loss": ANSWER_ONLY_LOSS,
    "mixed_precision": tf_settings["mixed_precision"],
    "xla": tf_settings["xla"],
}}
save_artifacts(model, tokenizer, MODEL_DIR, config)
print(f"Saved model, tokenizer, and config to {{MODEL_DIR}}")
"""

    zip_code = """
import shutil
from google.colab import files

archive_path = shutil.make_archive(MODEL_DIR, "zip", MODEL_DIR)
print(f"Created archive: {archive_path}")
files.download(archive_path)
"""

    return {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": {
            "accelerator": "GPU",
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {"name": "python"},
            "colab": {"gpuType": "T4"},
        },
        "cells": [
            make_notebook_cell(
                "markdown",
                f"# Train the tiny TG model on `{dataset_id}`\n\n"
                "Run this notebook in Google Colab with a GPU runtime. "
                "It downloads the Hugging Face dataset inside Colab, trains the model, "
                "and saves `model.keras`, `tokenizer.json`, and `config.json`.",
            ),
            make_notebook_cell("code", "%pip -q install datasets huggingface_hub"),
            make_notebook_cell(
                "code",
                "from pathlib import Path\n"
                f"Path('app.py').write_text({app_source!r}, encoding='utf-8')\n"
                "print('Wrote app.py into the Colab runtime.')",
            ),
            make_notebook_cell("code", train_code.strip()),
            make_notebook_cell(
                "markdown",
                "Run the next cell after training if you want to download the saved artifacts as a zip.",
            ),
            make_notebook_cell("code", zip_code.strip()),
        ],
    }


def write_colab_notebook(output_path=DEFAULT_COLAB_NOTEBOOK, dataset_id=TG_DATASET_PRO_ID):
    """Write the Colab notebook as JSON and return its path."""
    notebook = build_colab_notebook(dataset_id=dataset_id)
    path = Path(output_path)
    path.write_text(json.dumps(notebook, ensure_ascii=False, indent=2), encoding="utf-8")
    return path


def colab_command(args):
    """Create the Colab notebook described by ``args`` and print next steps."""
    notebook_path = write_colab_notebook(args.output, dataset_id=args.dataset_id)
    print(f"Created Colab notebook: {notebook_path}")
    print("Upload it to Google Colab, choose a GPU runtime, then run all cells.")


def ask_text(prompt, default=None, required=False):
    """Prompt for text on stdin.

    Returns the stripped input, or ``str(default)`` when the input is empty
    and a default exists. Without a default, empty input returns ``""`` unless
    ``required`` is set, in which case the question is repeated.
    """
    while True:
        suffix = f" [{default}]" if default is not None else ""
        value = input(f"{prompt}{suffix}: ").strip()
        if value:
            return value
        if default is not None:
            return str(default)
        if not required:
            return ""
        print("Please enter a value.")


def ask_int(prompt, default):
    """Prompt until the answer parses as an integer."""
    while True:
        value = ask_text(prompt, default)
        try:
            return int(value)
        except ValueError:
            print("Please enter a whole number.")


def ask_float(prompt, default):
    """Prompt until the answer parses as a float."""
    while True:
        value = ask_text(prompt, default)
        try:
            return float(value)
        except ValueError:
            print("Please enter a number.")


def ask_bool(prompt, default=True):
    """Prompt until the answer is a recognised yes/no value."""
    default_text = "y" if default else "n"
    while True:
        value = ask_text(prompt, default_text).lower()
        if value in {"y", "yes", "true", "1"}:
            return True
        if value in {"n", "no", "false", "0"}:
            return False
        print("Please enter y or n.")


class MenuArgs:
    """Plain attribute container standing in for parsed CLI arguments."""

    pass


def build_train_args_from_input():
    """Interactively collect training settings.

    Defaults are estimated from the dataset; if estimation fails for any
    reason, generic defaults are used and the reason is printed.
    """
    args = MenuArgs()
    data = ask_text("Dataset path(s), separated by commas", required=True)
    args.data = [part.strip() for part in data.split(",") if part.strip()]
    try:
        defaults = estimate_dataset_defaults(args.data)
        print("\nDataset estimate:")
        print(f"Records: {defaults['records']:,}")
        print(f"Tokens: {defaults['tokens']:,}")
        print(f"Unique tokenizer tokens: {defaults['unique_tokens']:,}")
        print("Recommended fast settings for this dataset are pre-filled below.")
    except Exception as exc:
        # Estimation is only a convenience; never block the training prompts.
        print(f"Could not estimate dataset defaults, using generic fast defaults: {exc}")
        defaults = {
            "epochs": 3,
            "batch_size": 128,
            "validation_split": 0.02,
            "maxlen": 96,
            "stride": 48,
            "max_vocab_size": 32000,
            "max_sequences": 60000,
            "embed_dim": 96,
            "num_heads": 4,
            "num_layers": 2,
            "ff_dim": 192,
            "dropout": 0.15,
            "answer_only_loss": True,
        }

    args.model_dir = ask_text("Model save directory", "tg-medium")
    args.epochs = ask_int("Epochs", defaults["epochs"])
    args.batch_size = ask_int("Batch size", defaults["batch_size"])
    args.validation_split = ask_float("Validation split", defaults["validation_split"])
    args.maxlen = ask_int("Context length / max sequence length", defaults["maxlen"])
    args.stride = ask_int("Training stride, higher is faster", defaults["stride"])
    args.max_vocab_size = ask_int("Max vocab size", defaults["max_vocab_size"])
    args.max_sequences = ask_int("Max training sequences, 0 = all", defaults["max_sequences"])
    args.embed_dim = ask_int("Embedding size", defaults["embed_dim"])
    args.num_heads = ask_int("Attention heads", defaults["num_heads"])
    args.num_layers = ask_int("Transformer blocks", defaults["num_layers"])
    args.ff_dim = ask_int("Feed-forward size", defaults["ff_dim"])
    args.dropout = ask_float("Dropout", defaults["dropout"])
    args.answer_only_loss = ask_bool("Train loss only on answer tokens", defaults["answer_only_loss"])
    args.mixed_precision = ask_bool("Use mixed precision on GPU", True)
    args.xla = ask_bool("Use XLA/JIT compile", True)
    return args


def build_chat_args_from_input():
    """Interactively collect the settings shared by chat and generate."""
    args = MenuArgs()
    args.model_dir = ask_text("Model directory to load", "tg-medium")
    args.maxlen = ask_int("Fallback context length if config is missing", 100)
    args.max_new_tokens = ask_int("Max new tokens per reply", 80)
    args.temperature = ask_float("Temperature, 0 for greedy", 0.8)
    return args


def build_generate_args_from_input():
    """Collect chat settings plus the single prompt to generate from."""
    args = build_chat_args_from_input()
    args.prompt = ask_text("Prompt", required=True)
    return args


def build_save_args_from_input():
    """Interactively collect the source and destination model directories."""
    args = MenuArgs()
    args.model_dir = ask_text("Model directory to load", "tg-medium")
    args.output_dir = ask_text("Output directory for saved copy", required=True)
    return args


def build_colab_args_from_input():
    """Interactively collect the dataset id and notebook output path."""
    args = MenuArgs()
    args.dataset_id = ask_text("Hugging Face dataset ID", TG_DATASET_PRO_ID)
    args.output = ask_text("Output notebook path", DEFAULT_COLAB_NOTEBOOK)
    return args


def print_menu():
    """Print the top-level menu."""
    print("\nSmall LM CLI")
    print("1. Train model")
    print("2. Chat with model")
    print("3. Generate one response")
    print("4. Save/copy loaded model artifacts")
    print("5. Create Colab auto-train notebook")
    print("6. Exit")


def main():
    """Run the interactive menu loop until the user exits."""
    while True:
        print_menu()
        try:
            choice = input("Choose an option: ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the menu means "exit", not a crash.
            print()
            break

        try:
            if choice in {"1", "train"}:
                train_command(build_train_args_from_input())
            elif choice in {"2", "chat"}:
                chat_command(build_chat_args_from_input())
            elif choice in {"3", "generate"}:
                generate_command(build_generate_args_from_input())
            elif choice in {"4", "save", "copy"}:
                save_command(build_save_args_from_input())
            elif choice in {"5", "colab"}:
                colab_command(build_colab_args_from_input())
            elif choice in {"6", "exit", "quit", "q"}:
                break
            else:
                print("Please choose 1, 2, 3, 4, 5, or 6.")
        except Exception as exc:
            # Top-level boundary: report the failure and return to the menu.
            print(f"Error: {exc}")


if __name__ == "__main__":
    main()