Spaces:
Running
Running
| import csv | |
| import json | |
| import re | |
| from collections import Counter | |
| from pathlib import Path | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.layers import Dense, Dropout, Embedding, LayerNormalization, MultiHeadAttention | |
| from tensorflow.keras.models import Model | |
# --- Special vocabulary entries ---------------------------------------------
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
START_TOKEN = "<start>"
END_TOKEN = "<end>"
# A fresh tokenizer starts its vocabulary with exactly this list, in this order.
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]

# --- Tokenizer patterns and defaults ----------------------------------------
# First-pass split: whole words, a fixed set of multi-character operators, and
# finally any single non-word character. Alternatives are tried left to right.
_TOKEN_ALTERNATIVES = (
    r"\w+",
    "==", "!=", "<=", ">=", "->", "=>", "::",
    r"\+\+", "--", "&&", r"\|\|",
    "//", r"/\*", r"\*/",
    r"[^\w]",
)
TOKEN_PATTERN = re.compile("|".join(_TOKEN_ALTERNATIVES), re.UNICODE)

TOKENIZER_VERSION = "regex_subword_v1"
DEFAULT_TOKEN_CHUNK_SIZE = 2
DEFAULT_WORD_CHUNK_SIZE = 4
DEFAULT_CYRILLIC_CHUNK_SIZE = 2

# Second-pass split inside a word: Cyrillic runs, acronym-style capital runs,
# (Capitalised) lowercase runs, digit runs, underscore runs, single non-word
# characters.
_IDENTIFIER_ALTERNATIVES = (
    r"[А-Яа-яЁё]+",
    r"[A-Z]+(?=[A-Z][a-z]|[0-9_]|$)",
    r"[A-Z]?[a-z]+",
    r"[0-9]+",
    r"_+",
    r"[^\w]",
)
IDENTIFIER_BOUNDARY_PATTERN = re.compile("|".join(_IDENTIFIER_ALTERNATIVES), re.UNICODE)

# Single characters that are always placed in a fitted vocabulary.
_LATIN_LOWER = "abcdefghijklmnopqrstuvwxyz"
_LATIN_UPPER = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
_DIGITS = "0123456789"
_CYRILLIC_LOWER = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
_CYRILLIC_UPPER = "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
_PUNCTUATION = "_ \n\t.,!?;:()[]{}<>+-*/=%'\"`\\|&^~#$@"
COMMON_FALLBACK_TOKENS = list(
    "".join(
        (
            _LATIN_LOWER,
            _LATIN_UPPER,
            _DIGITS,
            _CYRILLIC_LOWER,
            _CYRILLIC_UPPER,
            _PUNCTUATION,
        )
    )
)

# --- Dataset reading ---------------------------------------------------------
SUPPORTED_DATA_EXTENSIONS = {".txt", ".md", ".json", ".jsonl", ".ndjson", ".csv"}
# Record keys probed, in priority order, when turning a record into text.
TEXT_FIELDS = ("text", "content", "completion", "response", "output", "answer")
QUESTION_FIELDS = ("question", "prompt", "instruction", "input", "query")
PROMPT_FIELDS = QUESTION_FIELDS
ANSWER_FIELDS = ("answer", "completion", "response", "output")
# Labels used to render question/answer pairs as plain text.
QUESTION_LABEL = "Question:"
ANSWER_LABEL = "Answer:"

# --- Colab notebook generation ----------------------------------------------
TG_DATASET_PRO_ID = "AILaborant/tg_dataset_pro"
DEFAULT_COLAB_NOTEBOOK = "tg_dataset_pro_colab_train.ipynb"
class ThreeChunkTokenizer:
    """Regex subword tokenizer tuned for Russian prose and code-like text.

    Text is first split by ``TOKEN_PATTERN`` into word pieces, multi-character
    operators and single non-word characters. Word pieces are cut at
    identifier boundaries (``IDENTIFIER_BOUNDARY_PATTERN``) and then sliced
    into fixed-size chunks. Once a vocabulary has been fitted, chunks missing
    from it are re-segmented by greedy longest-prefix matching against the
    vocabulary.
    """

    def __init__(
        self,
        vocab=None,
        max_vocab_size=128000,
        chunk_size=DEFAULT_TOKEN_CHUNK_SIZE,
        word_chunk_size=DEFAULT_WORD_CHUNK_SIZE,
        cyrillic_chunk_size=DEFAULT_CYRILLIC_CHUNK_SIZE,
        tokenizer_version=TOKENIZER_VERSION,
    ):
        """Create a tokenizer, optionally from an existing vocabulary.

        Args:
            vocab: Iterable of tokens; a token's position is its id. When
                empty or ``None`` the vocabulary holds only the special tokens.
            max_vocab_size: Upper bound used by ``fit`` for learned tokens.
            chunk_size: Chunk length used by the ``"fixed_chunk_legacy"`` mode.
            word_chunk_size: Chunk length for non-Cyrillic word segments.
            cyrillic_chunk_size: Chunk length for Cyrillic word segments.
            tokenizer_version: ``"fixed_chunk_legacy"`` selects the legacy
                fixed-size chunking; any other value uses the regex path.
        """
        self.max_vocab_size = max_vocab_size
        # Clamp to >= 1 so chunk slicing always advances.
        self.chunk_size = max(1, int(chunk_size))
        self.word_chunk_size = max(1, int(word_chunk_size))
        self.cyrillic_chunk_size = max(1, int(cyrillic_chunk_size))
        self.tokenizer_version = tokenizer_version
        self._set_vocab(list(vocab) if vocab else list(SPECIAL_TOKENS))

    def _set_vocab(self, vocab):
        """Install ``vocab`` and rebuild the lookup structures derived from it."""
        self.vocab = vocab
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.max_token_length = max((len(token) for token in self.vocab), default=1)

    def _is_fitted(self):
        """Return True when the vocabulary holds more than the special tokens."""
        return len(self.vocab) > len(SPECIAL_TOKENS)

    def _is_cyrillic(self, text):
        """Return True if any character of ``text`` is a Russian Cyrillic letter."""
        return any("а" <= char.lower() <= "я" or char in "Ёё" for char in text)

    def _chunk_segment(self, segment):
        """Slice one word segment into chunks.

        Underscore runs are split per character, digit runs into groups of
        four, and everything else by the Cyrillic or generic word chunk size.
        """
        if not segment:
            return []
        if segment.startswith("_"):
            return list(segment)
        if segment.isdigit():
            return [segment[i : i + 4] for i in range(0, len(segment), 4)]
        chunk_size = self.cyrillic_chunk_size if self._is_cyrillic(segment) else self.word_chunk_size
        return [segment[i : i + chunk_size] for i in range(0, len(segment), chunk_size)]

    def _tokenize_word_piece(self, piece):
        """Split a word piece at identifier boundaries and chunk each segment.

        Characters the boundary pattern cannot match (for example accented
        Latin letters, or a capital letter directly before a Cyrillic one)
        are chunked as their own segments instead of being dropped, so
        joining the tokens reproduces ``piece``.
        """
        tokens = []
        cursor = 0
        for match in IDENTIFIER_BOUNDARY_PATTERN.finditer(piece):
            if match.start() > cursor:
                # Text skipped by the pattern: keep it rather than lose it.
                tokens.extend(self._chunk_segment(piece[cursor : match.start()]))
            tokens.extend(self._chunk_segment(match.group()))
            cursor = match.end()
        if cursor < len(piece):
            tokens.extend(self._chunk_segment(piece[cursor:]))
        return tokens

    def _tokenize_raw(self, text):
        """Tokenize ``text`` with the regex rules only, ignoring the vocabulary."""
        tokens = []
        for piece in TOKEN_PATTERN.findall(text):
            if piece.isalnum() or "_" in piece:
                tokens.extend(self._tokenize_word_piece(piece))
            else:
                tokens.append(piece)
        return tokens

    def _tokenize_fixed_chunk_legacy(self, text):
        """Tokenize by slicing every word piece into ``chunk_size`` characters."""
        tokens = []
        for piece in TOKEN_PATTERN.findall(text):
            if piece.isalnum() or "_" in piece:
                tokens.extend(piece[i : i + self.chunk_size] for i in range(0, len(piece), self.chunk_size))
            else:
                tokens.append(piece)
        return tokens

    def _tokenize_with_vocab(self, token):
        """Re-segment ``token`` into vocabulary entries, longest prefix first.

        A position where no vocabulary entry matches yields ``UNK_TOKEN`` and
        consumes one character.
        """
        if token in self.token_to_id:
            return [token]
        pieces = []
        index = 0
        while index < len(token):
            match = None
            max_end = min(len(token), index + self.max_token_length)
            for end in range(max_end, index, -1):
                candidate = token[index:end]
                if candidate in self.token_to_id:
                    match = candidate
                    break
            if match is None:
                pieces.append(UNK_TOKEN)
                index += 1
            else:
                pieces.append(match)
                index += len(match)
        return pieces

    def tokenize(self, text):
        """Return the token strings for ``text``.

        Legacy tokenizers use fixed-size chunking. Otherwise the regex tokens
        are returned as-is while unfitted, and re-segmented against the
        vocabulary once fitted.
        """
        if self.tokenizer_version == "fixed_chunk_legacy":
            return self._tokenize_fixed_chunk_legacy(text)
        tokens = self._tokenize_raw(text)
        if not self._is_fitted():
            return tokens
        segmented = []
        for token in tokens:
            segmented.extend(self._tokenize_with_vocab(token))
        return segmented

    def detokenize(self, tokens):
        """Concatenate ``tokens`` into text, skipping the special tokens."""
        return "".join(token for token in tokens if token not in SPECIAL_TOKENS)

    def fit(self, texts):
        """Build the vocabulary from ``texts`` and return ``self``.

        The vocabulary is the special tokens, then the fallback characters and
        every character seen in ``texts``, then the most frequent regex tokens
        until ``max_vocab_size`` is reached. Required single characters are
        never dropped, so the vocabulary can exceed ``max_vocab_size`` when
        they alone do not fit.
        """
        counts = Counter()
        character_counts = Counter()
        for text in texts:
            counts.update(self._tokenize_raw(text))
            character_counts.update(text)

        required_tokens = []
        seen = set(SPECIAL_TOKENS)
        for token in COMMON_FALLBACK_TOKENS + [token for token, _ in character_counts.most_common()]:
            if token not in seen:
                required_tokens.append(token)
                seen.add(token)

        remaining = max(0, self.max_vocab_size - len(SPECIAL_TOKENS) - len(required_tokens))
        learned_tokens = []
        for token, _ in counts.most_common():
            # Check the budget before adding, so a budget of zero adds nothing.
            if len(learned_tokens) >= remaining:
                break
            if token in seen:
                continue
            learned_tokens.append(token)
            seen.add(token)

        self._set_vocab(list(SPECIAL_TOKENS) + required_tokens + learned_tokens)
        return self

    def encode(self, text, add_boundaries=True, maxlen=None):
        """Convert ``text`` to token ids.

        Args:
            text: Input string.
            add_boundaries: Wrap the tokens in ``START_TOKEN`` / ``END_TOKEN``.
            maxlen: When given, truncate to this length and right-pad with the
                pad id up to it.

        Returns:
            A list of ints; tokens missing from the vocabulary map to the
            unknown-token id.
        """
        tokens = self.tokenize(text)
        if add_boundaries:
            tokens = [START_TOKEN] + tokens + [END_TOKEN]
        ids = [self.token_to_id.get(token, self.token_to_id[UNK_TOKEN]) for token in tokens]
        if maxlen is not None:
            ids = ids[:maxlen]
            ids.extend([self.token_to_id[PAD_TOKEN]] * (maxlen - len(ids)))
        return ids

    def decode_ids(self, ids):
        """Convert token ids back to text.

        Ids outside ``[0, len(vocab))`` are skipped; negative ids are rejected
        explicitly so they cannot index the vocabulary from its end.
        """
        tokens = [self.vocab[int(idx)] for idx in ids if 0 <= int(idx) < len(self.vocab)]
        return self.detokenize(tokens)

    def to_dict(self):
        """Return a JSON-serializable snapshot accepted by ``from_dict``."""
        return {
            "vocab": self.vocab,
            "max_vocab_size": self.max_vocab_size,
            "chunk_size": self.chunk_size,
            "word_chunk_size": self.word_chunk_size,
            "cyrillic_chunk_size": self.cyrillic_chunk_size,
            "tokenizer_version": self.tokenizer_version,
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a tokenizer from the mapping produced by ``to_dict``.

        Declared as a classmethod so it can be called on the class itself,
        e.g. ``ThreeChunkTokenizer.from_dict(payload)``.
        """
        # Older saved artifacts did not store this value and were trained with
        # 3-character chunks, so keep those models' prompt tokenization stable.
        return cls(
            vocab=data["vocab"],
            max_vocab_size=data.get("max_vocab_size", 128000),
            chunk_size=data.get("chunk_size", 3),
            word_chunk_size=data.get("word_chunk_size", data.get("chunk_size", 3)),
            cyrillic_chunk_size=data.get("cyrillic_chunk_size", data.get("chunk_size", 3)),
            tokenizer_version=data.get("tokenizer_version", "fixed_chunk_legacy"),
        )
class TransformerBlock(tf.keras.layers.Layer):
    """Causally masked self-attention block.

    Both sublayers (attention, then a two-layer feed-forward network) are
    followed by dropout, a residual add and layer normalization.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        """Store the hyperparameters and create the sublayers.

        Args:
            embed_dim: Output width of the block; also the attention key_dim.
            num_heads: Number of attention heads.
            ff_dim: Width of the hidden feed-forward layer.
            rate: Dropout rate applied after each sublayer.
            **kwargs: Forwarded to the Keras ``Layer`` constructor.
        """
        super().__init__(**kwargs)
        # Kept as attributes so get_config() can serialize them.
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.ff_dim, self.rate = ff_dim, rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        feed_forward_layers = [
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ]
        self.ffn = tf.keras.Sequential(feed_forward_layers)
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply masked self-attention and the feed-forward network."""
        length = tf.shape(inputs)[1]
        # Lower-triangular mask: a position attends to itself and earlier ones.
        all_pairs = tf.ones((length, length), dtype=tf.bool)
        causal_mask = tf.linalg.band_part(all_pairs, -1, 0)

        attended = self.att(inputs, inputs, attention_mask=causal_mask)
        attended = self.dropout1(attended, training=training)
        # Cast back to the input dtype before the residual add.
        residual = self.layernorm1(inputs + tf.cast(attended, inputs.dtype))

        projected = self.ffn(residual)
        projected = self.dropout2(projected, training=training)
        return self.layernorm2(residual + tf.cast(projected, residual.dtype))

    def build(self, input_shape):
        """Build the sublayers from ``input_shape``."""
        self.att.build(input_shape, input_shape)
        self.ffn.build(input_shape)
        for norm in (self.layernorm1, self.layernorm2):
            norm.build(input_shape)
        super().build(input_shape)

    def get_config(self):
        """Return the base layer config plus this block's hyperparameters."""
        return {
            **super().get_config(),
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        }
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        """Create the two embedding tables.

        Args:
            maxlen: Number of rows in the position table.
            vocab_size: Number of rows in the token table.
            embed_dim: Width of both embeddings.
            **kwargs: Forwarded to the Keras ``Layer`` constructor.
        """
        super().__init__(**kwargs)
        # Kept as attributes so get_config() can serialize them.
        self.maxlen, self.vocab_size, self.embed_dim = maxlen, vocab_size, embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Embed token ids ``x`` and add the embedding of each position."""
        sequence_length = tf.shape(x)[-1]
        position_ids = tf.range(start=0, limit=sequence_length, delta=1)
        position_vectors = self.pos_emb(position_ids)
        token_vectors = self.token_emb(x)
        # Match dtypes before adding the two embeddings.
        return token_vectors + tf.cast(position_vectors, token_vectors.dtype)

    def build(self, input_shape):
        """Build both embedding tables."""
        self.token_emb.build(input_shape)
        self.pos_emb.build((self.maxlen,))
        super().build(input_shape)

    def get_config(self):
        """Return the base layer config plus the constructor arguments."""
        return {
            **super().get_config(),
            "maxlen": self.maxlen,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        }
def configure_tensorflow(use_mixed_precision=True, use_xla=True):
    """Apply best-effort TensorFlow runtime settings and report what took effect.

    Args:
        use_mixed_precision: Request the ``mixed_float16`` policy; only
            attempted when at least one GPU is visible.
        use_xla: Request XLA JIT compilation.

    Returns:
        A dict with ``gpu_count``, ``memory_growth``, ``mixed_precision`` and
        ``xla``; the last three are True only if the setting was applied.
    """
    gpus = tf.config.list_physical_devices("GPU")

    memory_growth_enabled = False
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except (RuntimeError, ValueError):
            # Colab may initialize the GPU before training starts. In that case
            # memory growth cannot be changed, but training can continue safely.
            continue
        memory_growth_enabled = True

    mixed_precision_enabled = False
    if gpus and use_mixed_precision:
        try:
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
        except ValueError:
            pass
        else:
            mixed_precision_enabled = True

    xla_enabled = False
    if use_xla:
        try:
            tf.config.optimizer.set_jit(True)
        except ValueError:
            pass
        else:
            xla_enabled = True

    return {
        "gpu_count": len(gpus),
        "memory_growth": memory_growth_enabled,
        "mixed_precision": mixed_precision_enabled,
        "xla": xla_enabled,
    }
def build_model(
    vocab_size,
    maxlen=100,
    embed_dim=96,
    num_heads=4,
    ff_dim=192,
    dropout=0.2,
    num_layers=1,
    jit_compile=True,
):
    """Build and compile the next-token prediction model.

    The model maps ``maxlen`` token ids to a softmax over the vocabulary at
    every position: embedding, then at least one ``TransformerBlock``, then a
    float32 dense softmax head.

    Args:
        vocab_size: Size of the embedding table and of the output softmax.
        maxlen: Fixed input sequence length.
        embed_dim: Embedding and block width.
        num_heads: Attention heads per block.
        ff_dim: Hidden width of each block's feed-forward network.
        dropout: Dropout rate inside each block.
        num_layers: Number of blocks; values below 1 are treated as 1.
        jit_compile: Passed through to ``Model.compile``.

    Returns:
        The compiled Keras model.
    """
    block_count = max(1, int(num_layers))

    token_ids = tf.keras.Input(shape=(maxlen,), dtype=tf.int32)
    hidden = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(token_ids)
    for _ in range(block_count):
        hidden = TransformerBlock(embed_dim, num_heads, ff_dim, dropout)(hidden)
    # The head is pinned to float32 regardless of the global dtype policy.
    probabilities = Dense(vocab_size, activation="softmax", dtype="float32")(hidden)

    model = Model(inputs=token_ids, outputs=probabilities)
    answer_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="answer_accuracy")
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        weighted_metrics=[answer_accuracy],
        jit_compile=jit_compile,
    )
    return model
def stringify_value(value):
    """Render an arbitrary JSON-like value as stripped text.

    ``None`` becomes ``""``; strings are stripped; numbers and booleans use
    ``str``; lists join their non-empty rendered items with newlines; dicts
    are rendered by ``record_to_text``; anything else is ``str(value)``
    stripped.
    """
    if value is None:
        return ""
    if isinstance(value, dict):
        return record_to_text(value)
    if isinstance(value, list):
        rendered_items = [stringify_value(item) for item in value]
        return "\n".join(item for item in rendered_items if item)
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, (int, float, bool)):
        return str(value)
    return str(value).strip()
def messages_to_text(messages):
    """Render a chat-style message list as newline-separated text.

    Dict messages become ``"role: content"`` (or just the content when no
    role is found); the role is looked up under ``role``/``from``/``speaker``
    and the content under ``content``/``value``/``text``. Non-dict items are
    rendered with ``stringify_value``. Entries with no content are skipped.
    """
    rendered = []
    for message in messages:
        if isinstance(message, dict):
            role = stringify_value(message.get("role") or message.get("from") or message.get("speaker"))
            content = stringify_value(message.get("content") or message.get("value") or message.get("text"))
            if not content:
                continue
            rendered.append(f"{role}: {content}" if role else content)
        else:
            plain = stringify_value(message)
            if plain:
                rendered.append(plain)
    return "\n".join(rendered).strip()
def first_field(record, field_names):
    """Return the first non-empty rendered value among ``field_names``.

    Fields are tried in order; ``""`` is returned when none has text.
    """
    rendered_values = (stringify_value(record.get(name)) for name in field_names)
    return next((text for text in rendered_values if text), "")
def format_question_answer(question, answer):
    """Render a question/answer pair as two labelled lines."""
    question_line = f"{QUESTION_LABEL} {question}"
    answer_line = f"{ANSWER_LABEL} {answer}"
    return "\n".join((question_line, answer_line))
def format_question_prompt(question):
    """Render a question followed by a bare answer label."""
    question_line = f"{QUESTION_LABEL} {question}"
    return "\n".join((question_line, ANSWER_LABEL))
def record_to_text(record):
    """Turn one dataset record into training text.

    Tried in order: a question/answer pair, a ``messages`` list, a single
    direct text field, and finally a ``key: value`` dump of every non-empty
    field. Non-dict records are rendered with ``stringify_value``.
    """
    if not isinstance(record, dict):
        return stringify_value(record)

    question = first_field(record, QUESTION_FIELDS)
    answer = first_field(record, ANSWER_FIELDS)
    if question and answer:
        return format_question_answer(question, answer)

    messages = record.get("messages")
    if isinstance(messages, list):
        conversation = messages_to_text(messages)
        if conversation:
            return conversation

    direct_text = first_field(record, TEXT_FIELDS)
    if direct_text:
        return direct_text

    # Last resort: dump every field that renders to something.
    rendered_fields = ((key, stringify_value(value)) for key, value in record.items())
    return "\n".join(f"{key}: {text}" for key, text in rendered_fields if text).strip()
def walk_json_records(data):
    """Yield the individual records contained in a parsed JSON document.

    Lists are flattened recursively. A dict that holds a list under any of
    the known container keys is replaced by the records of those lists;
    any other dict, and every non-container value, is yielded as is.
    """
    if isinstance(data, list):
        for item in data:
            yield from walk_json_records(item)
        return
    if not isinstance(data, dict):
        yield data
        return

    container_keys = ("train", "validation", "test", "data", "examples", "rows", "records")
    nested_lists = [data[key] for key in container_keys if isinstance(data.get(key), list)]
    if not nested_lists:
        # No container key holds a list, so the dict itself is one record.
        yield data
        return
    for nested in nested_lists:
        yield from walk_json_records(nested)
def read_json_file(path):
    """Read a ``.json`` file and return one text per record.

    The file is decoded as UTF-8 with an optional BOM. If it is not a single
    JSON document because of trailing data, it is re-read as JSON Lines;
    any other parse error propagates.
    """
    raw_text = path.read_text(encoding="utf-8-sig")
    try:
        document = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        # Some datasets are JSONL but are still named .json. If parsing a full
        # JSON document fails, try reading one JSON object per line.
        if "Extra data" in exc.msg:
            return read_jsonl_text(raw_text, path)
        raise
    return [record_to_text(record) for record in walk_json_records(document)]
def read_jsonl_text(raw_text, path):
    """Parse JSON Lines content and return one text per non-blank line.

    Args:
        raw_text: The file content.
        path: Source path, used only in the error message.

    Raises:
        ValueError: A non-blank line is not valid JSON; the message names the
            1-based line number and ``path``.
    """
    texts = []
    for line_number, raw_line in enumerate(raw_text.splitlines(), start=1):
        payload = raw_line.strip()
        if payload:
            try:
                record = json.loads(payload)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON on line {line_number} in {path}") from exc
            texts.append(record_to_text(record))
    return texts
def read_jsonl_file(path):
    """Read a JSON Lines file (UTF-8, optional BOM) into a list of texts."""
    raw_text = path.read_text(encoding="utf-8-sig")
    return read_jsonl_text(raw_text, path)
def read_csv_file(path):
    """Read a CSV file with a header row and return one text per data row.

    Each row is passed to ``record_to_text`` as a dict keyed by the header
    names.
    """
    # "utf-8-sig" strips a leading byte-order mark if present. With plain
    # "utf-8" the BOM would be glued to the first header name, which would
    # then never match a known field name. The JSON readers already decode
    # this way.
    with path.open("r", encoding="utf-8-sig", newline="") as file:
        return [record_to_text(row) for row in csv.DictReader(file)]
def read_dataset_file(path):
    """Read one dataset file and return its training texts.

    The reader is chosen by the lowercased file extension: ``.txt``/``.md``
    give a single text holding the whole file, ``.json``, ``.jsonl``/
    ``.ndjson`` and ``.csv`` give one text per record.

    Args:
        path: File location, as a ``pathlib.Path`` or a string.

    Returns:
        A list of strings; empty for unsupported extensions.
    """
    path = Path(path)
    extension = path.suffix.lower()
    if extension in {".txt", ".md"}:
        # Decode like the JSON readers so a leading byte-order mark does not
        # end up inside the training text.
        return [path.read_text(encoding="utf-8-sig")]
    if extension == ".json":
        return read_json_file(path)
    if extension in {".jsonl", ".ndjson"}:
        return read_jsonl_file(path)
    if extension == ".csv":
        return read_csv_file(path)
    return []
def read_training_texts(paths):
    """Collect the training texts from files and directories.

    Directories are searched recursively, in sorted order, for files with a
    supported extension; any other path is read directly. Blank texts are
    discarded.

    Raises:
        ValueError: No non-blank text was found.
    """
    collected = []
    for raw_path in paths:
        location = Path(raw_path)
        if not location.is_dir():
            collected.extend(read_dataset_file(location))
            continue
        for candidate in sorted(location.rglob("*")):
            if candidate.is_file() and candidate.suffix.lower() in SUPPORTED_DATA_EXTENSIONS:
                collected.extend(read_dataset_file(candidate))

    texts = [text for text in collected if text and text.strip()]
    if texts:
        return texts
    supported = ", ".join(sorted(SUPPORTED_DATA_EXTENSIONS))
    raise ValueError(f"No training text found. Supported dataset files: {supported}")
def estimate_dataset_defaults(paths):
    """Suggest training hyperparameters from the size of a dataset.

    The texts are tokenized with an unfitted tokenizer to count tokens,
    distinct tokens and per-text lengths. Every length includes two extra
    positions for the start and end tokens.

    Args:
        paths: Files or directories accepted by ``read_training_texts``.

    Returns:
        A dict with dataset statistics (``records``, ``tokens``,
        ``unique_tokens``) and recommended training settings.
    """
    texts = read_training_texts(paths)
    # One unfitted tokenizer serves every text; tokenize() does not alter it.
    tokenizer = ThreeChunkTokenizer()
    token_frequencies = Counter()
    sequence_lengths = []
    for text in texts:
        tokens = tokenizer.tokenize(text)
        token_frequencies.update(tokens)
        sequence_lengths.append(len(tokens) + 2)

    token_count = sum(sequence_lengths)
    unique_tokens = len(token_frequencies)
    length_p95 = int(np.percentile(sequence_lengths, 95)) if sequence_lengths else 96
    # Round the 95th-percentile length up to a multiple of 32, within [96, 256].
    recommended_maxlen = min(256, max(96, ((length_p95 + 31) // 32) * 32))
    recommended_vocab = min(32000, max(2048, unique_tokens + len(SPECIAL_TOKENS) + 512))
    # Large corpora get a wider stride and a cap on the number of windows.
    is_large = token_count >= 1_000_000
    recommended_stride = 48 if is_large else 32
    recommended_max_sequences = 60000 if is_large else 0
    return {
        "records": len(texts),
        "tokens": token_count,
        "unique_tokens": unique_tokens,
        "epochs": 3,
        "batch_size": 128,
        "validation_split": 0.02,
        "maxlen": recommended_maxlen,
        "stride": recommended_stride,
        "max_vocab_size": recommended_vocab,
        "max_sequences": recommended_max_sequences,
        "embed_dim": 96,
        "num_heads": 4,
        "num_layers": 2,
        "ff_dim": 192,
        "dropout": 0.15,
        "answer_only_loss": True,
    }
def find_subsequence(values, pattern):
    """Return the index of the first occurrence of ``pattern`` in ``values``.

    Returns -1 when ``pattern`` is empty, longer than ``values``, or absent.
    """
    if not pattern or len(pattern) > len(values):
        return -1
    width = len(pattern)
    candidate_starts = range(len(values) - width + 1)
    return next(
        (start for start in candidate_starts if values[start : start + width] == pattern),
        -1,
    )
def make_loss_mask(ids, tokenizer, answer_only_loss=True):
    """Build a float32 per-token loss mask for one encoded text.

    With ``answer_only_loss`` the mask is 0 up to and including the first
    occurrence of the encoded answer label and 1 after it. The mask is all
    ones when the option is off or the label does not occur in ``ids``.
    """
    length = len(ids)
    if answer_only_loss:
        label_ids = tokenizer.encode(ANSWER_LABEL, add_boundaries=False)
        label_start = find_subsequence(ids, label_ids)
        if label_start >= 0:
            mask = np.zeros(length, dtype="float32")
            mask[label_start + len(label_ids) :] = 1.0
            return mask
    return np.ones(length, dtype="float32")
def make_training_arrays(texts, tokenizer, maxlen, stride=48, max_sequences=0, answer_only_loss=True):
    """Cut encoded texts into fixed windows of inputs, targets and weights.

    Every text is encoded with boundary tokens, right-padded to at least
    ``maxlen + 1`` ids, and sliced into windows every ``stride`` ids, plus a
    final window aligned to the end of the text. Inputs are the first
    ``maxlen`` ids of a window, targets the last ``maxlen``.

    Args:
        texts: Iterable of training strings.
        tokenizer: Provides ``encode`` and ``token_to_id``.
        maxlen: Input length of each training sequence.
        stride: Step between window starts; clamped to at least 1.
        max_sequences: Stop once this many windows exist; 0 means no limit.
        answer_only_loss: Weight only the tokens after the answer label and
            drop windows whose targets all have zero weight.

    Returns:
        ``(inputs, targets, sample_weights)``: int32, int32 and float32
        arrays, one row per window and ``maxlen`` columns each.
    """
    stride = max(1, int(stride))
    max_sequences = max(0, int(max_sequences))
    pad_id = tokenizer.token_to_id[PAD_TOKEN]
    window = maxlen + 1
    sequences = []
    weights = []

    def as_arrays():
        # Inputs and targets are the same window shifted by one position.
        data = np.asarray(sequences, dtype="int32")
        sample_weights = np.asarray(weights, dtype="float32")
        return data[:, :-1], data[:, 1:], sample_weights

    for text in texts:
        ids = tokenizer.encode(text, add_boundaries=True)
        loss_mask = make_loss_mask(ids, tokenizer, answer_only_loss=answer_only_loss)
        shortfall = window - len(ids)
        if shortfall > 0:
            # Padding positions never contribute to the loss.
            ids.extend([pad_id] * shortfall)
            loss_mask = np.pad(loss_mask, (0, shortfall), constant_values=0.0)

        last_start = max(0, len(ids) - window)
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            # Always include the window that ends at the end of the text.
            starts.append(last_start)

        for start in starts:
            target_weights = loss_mask[start + 1 : start + window]
            if answer_only_loss and not np.any(target_weights):
                continue
            sequences.append(ids[start : start + window])
            weights.append(target_weights)
            if max_sequences and len(sequences) >= max_sequences:
                return as_arrays()

    if not sequences:
        # Fall back to one all-padding window so the arrays are never empty.
        sequences.append([pad_id] * window)
        weights.append(np.ones(maxlen, dtype="float32"))
    return as_arrays()
def make_tf_datasets(x, y, sample_weights, batch_size, validation_split=0.0, shuffle_buffer=10000):
    """Shuffle, split and wrap the training arrays as ``tf.data`` datasets.

    Rows are permuted with a fixed seed (42), then the last
    ``int(len(x) * validation_split)`` rows become the validation set. A
    positive split that rounds down to zero still reserves one row when
    there is more than one row.

    Returns:
        ``(train_ds, val_ds, train_count, val_count)``; ``val_ds`` is
        ``None`` and ``val_count`` is 0 when nothing is held out.
    """
    order = np.random.default_rng(42).permutation(len(x))
    x, y, sample_weights = x[order], y[order], sample_weights[order]

    val_count = int(len(x) * validation_split)
    if val_count == 0 and validation_split > 0 and len(x) > 1:
        val_count = 1

    if val_count:
        train_parts = (x[:-val_count], y[:-val_count], sample_weights[:-val_count])
        val_parts = (x[-val_count:], y[-val_count:], sample_weights[-val_count:])
    else:
        train_parts = (x, y, sample_weights)
        val_parts = None

    train_size = len(train_parts[0])
    train_ds = tf.data.Dataset.from_tensor_slices(train_parts)
    train_ds = train_ds.shuffle(min(train_size, shuffle_buffer), reshuffle_each_iteration=True)
    train_ds = train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    if val_parts is None:
        return train_ds, None, train_size, 0

    val_ds = tf.data.Dataset.from_tensor_slices(val_parts)
    val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds, train_size, len(val_parts[0])
def model_paths(model_dir):
    """Return the (model, tokenizer, config) file paths inside ``model_dir``."""
    base = Path(model_dir)
    file_names = ("model.keras", "tokenizer.json", "config.json")
    return tuple(base / file_name for file_name in file_names)
def save_artifacts(model, tokenizer, model_dir, config):
    """Write the model, tokenizer and config into ``model_dir``.

    The directory is created if needed. The tokenizer and config are stored
    as indented UTF-8 JSON with non-ASCII characters kept as is.
    """
    target_dir = Path(model_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    model_path, tokenizer_path, config_path = model_paths(target_dir)

    model.save(model_path)
    json_payloads = ((tokenizer_path, tokenizer.to_dict()), (config_path, config))
    for destination, payload in json_payloads:
        destination.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def load_artifacts(model_dir="tg-medium"):
    """Load the model, tokenizer and config saved by ``save_artifacts``.

    Returns:
        ``(model, tokenizer, config)``; ``config`` is ``{}`` when the config
        file does not exist. The model is loaded without compiling.

    Raises:
        FileNotFoundError: The model file or the tokenizer file is missing.
    """
    model_path, tokenizer_path, config_path = model_paths(model_dir)
    required_files = (("Model", model_path), ("Tokenizer", tokenizer_path))
    for label, required_path in required_files:
        if not required_path.exists():
            raise FileNotFoundError(f"{label} file not found: {required_path}")

    # The custom layers must be registered for Keras to rebuild the model.
    custom_layers = {
        "TokenAndPositionEmbedding": TokenAndPositionEmbedding,
        "TransformerBlock": TransformerBlock,
    }
    model = tf.keras.models.load_model(model_path, custom_objects=custom_layers, compile=False)

    tokenizer_payload = json.loads(tokenizer_path.read_text(encoding="utf-8"))
    tokenizer = ThreeChunkTokenizer.from_dict(tokenizer_payload)

    config = {}
    if config_path.exists():
        config = json.loads(config_path.read_text(encoding="utf-8"))
    return model, tokenizer, config
def clean_answer_text(text):
    """Trim generated text down to the answer itself.

    Everything from the first question label or ``"Prompt:"`` marker onward
    is cut, and a leading answer label is removed.
    """
    cleaned = text.strip()
    for marker in (QUESTION_LABEL, "Prompt:"):
        # partition() keeps the whole text when the marker is absent.
        cleaned = cleaned.partition(marker)[0].strip()
    if cleaned.startswith(ANSWER_LABEL):
        cleaned = cleaned[len(ANSWER_LABEL) :].strip()
    return cleaned.strip()
def generate_text(model, tokenizer, prompt, maxlen=100, max_new_tokens=80, temperature=1.0, qa_mode=True):
    """Generate a continuation of ``prompt`` one token at a time.

    Args:
        model: Model whose ``predict`` returns per-position token
            probabilities for a batch of ``maxlen`` ids.
        tokenizer: Provides ``token_to_id``, ``encode`` and ``decode_ids``.
        prompt: User text; wrapped in the question/answer template when
            ``qa_mode`` is set.
        maxlen: Model input length; only the most recent ``maxlen`` ids are
            fed to the model.
        max_new_tokens: Upper bound on generated tokens.
        temperature: ``<= 0`` picks the most likely token; otherwise tokens
            are sampled from the temperature-scaled distribution.
        qa_mode: Apply the question template and clean the answer text.

    Returns:
        The decoded text, cleaned with ``clean_answer_text`` in QA mode.
    """
    if qa_mode:
        prompt = format_question_prompt(prompt)
    context = [tokenizer.token_to_id[START_TOKEN]]
    context.extend(tokenizer.encode(prompt, add_boundaries=False))
    end_id = tokenizer.token_to_id[END_TOKEN]
    pad_id = tokenizer.token_to_id[PAD_TOKEN]
    generated = []
    for _ in range(max_new_tokens):
        input_ids = context[-maxlen:]
        # The prediction for the next token sits at the last real position.
        last_position = len(input_ids) - 1
        padded = input_ids + [pad_id] * (maxlen - len(input_ids))
        predictions = model.predict(np.asarray([padded], dtype="int32"), verbose=0)[0, last_position]
        if temperature <= 0:
            next_id = int(np.argmax(predictions))
        else:
            # Work in float64 and subtract the largest logit before
            # exponentiating. Otherwise a small temperature makes every
            # exp() underflow to zero, the normalisation divides by zero,
            # and np.random.choice rejects the NaN probabilities.
            probabilities_in = np.asarray(predictions, dtype="float64")
            logits = np.log(np.maximum(probabilities_in, 1e-9)) / temperature
            weights = np.exp(logits - np.max(logits))
            probabilities = weights / np.sum(weights)
            next_id = int(np.random.choice(len(probabilities), p=probabilities))
        if next_id == end_id:
            break
        context.append(next_id)
        generated.append(next_id)
    output = tokenizer.decode_ids(generated)
    return clean_answer_text(output) if qa_mode else output
def train_command(args):
    """Train a model from local dataset files and save all artifacts.

    ``args`` supplies the data paths, model directory, hyperparameters and
    runtime switches. The tokenizer is fitted on the same texts the model is
    trained on.
    """
    tf_settings = configure_tensorflow(
        use_mixed_precision=args.mixed_precision,
        use_xla=args.xla,
    )
    texts = read_training_texts(args.data)
    tokenizer = ThreeChunkTokenizer(max_vocab_size=args.max_vocab_size).fit(texts)
    vocab_size = len(tokenizer.vocab)

    x, y, sample_weights = make_training_arrays(
        texts,
        tokenizer,
        args.maxlen,
        stride=args.stride,
        max_sequences=args.max_sequences,
        answer_only_loss=args.answer_only_loss,
    )
    train_ds, val_ds, train_count, val_count = make_tf_datasets(
        x,
        y,
        sample_weights,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
    )
    model = build_model(
        vocab_size=vocab_size,
        maxlen=args.maxlen,
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
        ff_dim=args.ff_dim,
        dropout=args.dropout,
        num_layers=args.num_layers,
        jit_compile=args.xla,
    )

    validation_note = f" with {val_count:,} validation sequences" if val_count else ""
    print(f"Training on {train_count:,} sequences{validation_note}.")
    summary_parts = [
        f"Vocab: {vocab_size:,}",
        f"Batch: {args.batch_size}",
        f"Stride: {args.stride}",
        f"Blocks: {args.num_layers}",
        f"GPUs: {tf_settings['gpu_count']}",
        f"Answer-only loss: {args.answer_only_loss}",
        f"Memory growth: {tf_settings['memory_growth']}",
        f"Mixed precision: {tf_settings['mixed_precision']}",
        f"XLA: {tf_settings['xla']}",
    ]
    print(" | ".join(summary_parts))

    # Everything needed to rebuild the model and tokenizer for inference.
    config = {
        "maxlen": args.maxlen,
        "embed_dim": args.embed_dim,
        "num_heads": args.num_heads,
        "num_layers": args.num_layers,
        "ff_dim": args.ff_dim,
        "dropout": args.dropout,
        "vocab_size": vocab_size,
        "tokenizer_chunk_size": tokenizer.chunk_size,
        "tokenizer_word_chunk_size": tokenizer.word_chunk_size,
        "tokenizer_cyrillic_chunk_size": tokenizer.cyrillic_chunk_size,
        "tokenizer_version": tokenizer.tokenizer_version,
        "stride": args.stride,
        "max_sequences": args.max_sequences,
        "answer_only_loss": args.answer_only_loss,
        "mixed_precision": tf_settings["mixed_precision"],
        "xla": tf_settings["xla"],
    }
    model.fit(train_ds, epochs=args.epochs, validation_data=val_ds)
    save_artifacts(model, tokenizer, args.model_dir, config)
    print(f"Saved model, tokenizer, and config to {args.model_dir}")
def chat_command(args):
    """Run an interactive question/answer loop on the terminal.

    The loop ends on ``/exit`` or ``/quit`` (case-insensitive), end of
    input, or Ctrl+C. The saved config's ``maxlen`` takes precedence over
    ``args.maxlen``.
    """
    model, tokenizer, config = load_artifacts(args.model_dir)
    maxlen = int(config.get("maxlen", args.maxlen))
    stop_words = {"/exit", "/quit"}
    print("Chat ready. Type /exit or /quit to stop.")
    while True:
        try:
            prompt = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line before leaving.
            print()
            return
        if prompt.lower() in stop_words:
            return
        response = generate_text(
            model,
            tokenizer,
            prompt,
            maxlen=maxlen,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
        print(f"bot> {response}")
def generate_command(args):
    """Generate one answer for ``args.prompt`` and print it.

    The saved config's ``maxlen`` takes precedence over ``args.maxlen``.
    """
    model, tokenizer, config = load_artifacts(args.model_dir)
    maxlen = int(config.get("maxlen", args.maxlen))
    answer = generate_text(
        model,
        tokenizer,
        args.prompt,
        maxlen=maxlen,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
    )
    print(answer)
def save_command(args):
    """Copy saved artifacts from ``args.model_dir`` to ``args.output_dir``."""
    loaded = load_artifacts(args.model_dir)
    model, tokenizer, config = loaded
    save_artifacts(model, tokenizer, args.output_dir, config)
    print(f"Saved copy to {args.output_dir}")
def make_notebook_cell(cell_type, source):
    """Build one Jupyter notebook cell as a plain dict.

    Args:
        cell_type: Cell type string such as ``"code"`` or ``"markdown"``.
        source: Cell text; stored as a list of lines, each ending in a
            newline.

    Returns:
        The cell dict. Code cells also carry an empty ``outputs`` list and a
        null ``execution_count``, which notebook format v4 requires of every
        code cell.
    """
    cell = {
        "cell_type": cell_type,
        "metadata": {},
        "source": [line + "\n" for line in source.splitlines()],
    }
    if cell_type == "code":
        # A code cell that has never run: no execution count, no outputs.
        cell["execution_count"] = None
        cell["outputs"] = []
    return cell
def build_colab_notebook(dataset_id=TG_DATASET_PRO_ID):
    """Build a Colab training notebook as a notebook-format dict.

    The notebook installs the dataset libraries, writes a copy of this
    module's source to ``app.py`` in the runtime, trains on ``dataset_id``
    using the functions imported from that copy, and offers a final cell
    that zips and downloads the saved artifacts.

    Args:
        dataset_id: Hugging Face dataset identifier embedded in the notebook.

    Returns:
        A dict ready to be serialized as an ``.ipynb`` file.
    """
    # The running module's own source is embedded so the notebook is
    # self-contained.
    app_source = Path(__file__).read_text(encoding="utf-8")

    # Doubled braces survive this f-string as literal braces, so the
    # generated code keeps its own f-strings and dict literal.
    train_code = f"""
DATASET_ID = {dataset_id!r}
SPLIT = "train"
MODEL_DIR = "/content/tg-dataset-pro-model"
# 0 means use the full dataset. Set a smaller number for a quick smoke test.
MAX_ROWS = 0
# These defaults are intentionally Colab-friendly. Increase MAX_SEQUENCES,
# EMBED_DIM, NUM_LAYERS, or EPOCHS if you have a stronger GPU/runtime.
EPOCHS = 3
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.02
MAXLEN = 128
STRIDE = 64
MAX_VOCAB_SIZE = 32000
MAX_SEQUENCES = 120000
EMBED_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 2
FF_DIM = 256
DROPOUT = 0.15
ANSWER_ONLY_LOSS = True
from datasets import load_dataset
from app import (
    ThreeChunkTokenizer,
    build_model,
    configure_tensorflow,
    make_tf_datasets,
    make_training_arrays,
    record_to_text,
    save_artifacts,
)
tf_settings = configure_tensorflow(use_mixed_precision=True, use_xla=True)
print("TensorFlow settings:", tf_settings)
raw_dataset = load_dataset(DATASET_ID, split=SPLIT)
texts = []
for index, record in enumerate(raw_dataset):
    if MAX_ROWS and index >= MAX_ROWS:
        break
    text = record_to_text(record)
    if text and text.strip():
        texts.append(text)
if not texts:
    raise ValueError("No text rows were found in the Hugging Face dataset.")
print(f"Loaded {{len(texts):,}} training rows from {{DATASET_ID}}.")
tokenizer = ThreeChunkTokenizer(max_vocab_size=MAX_VOCAB_SIZE).fit(texts)
x, y, sample_weights = make_training_arrays(
    texts,
    tokenizer,
    MAXLEN,
    stride=STRIDE,
    max_sequences=MAX_SEQUENCES,
    answer_only_loss=ANSWER_ONLY_LOSS,
)
train_ds, val_ds, train_count, val_count = make_tf_datasets(
    x,
    y,
    sample_weights,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)
model = build_model(
    vocab_size=len(tokenizer.vocab),
    maxlen=MAXLEN,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    dropout=DROPOUT,
    num_layers=NUM_LAYERS,
    jit_compile=True,
)
print(
    f"Training on {{train_count:,}} sequences"
    f"{{f' with {{val_count:,}} validation sequences' if val_count else ''}}."
)
print(f"Vocab: {{len(tokenizer.vocab):,}} | Maxlen: {{MAXLEN}} | Batch: {{BATCH_SIZE}}")
model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
config = {{
    "dataset_id": DATASET_ID,
    "dataset_split": SPLIT,
    "max_rows": MAX_ROWS,
    "maxlen": MAXLEN,
    "embed_dim": EMBED_DIM,
    "num_heads": NUM_HEADS,
    "num_layers": NUM_LAYERS,
    "ff_dim": FF_DIM,
    "dropout": DROPOUT,
    "vocab_size": len(tokenizer.vocab),
    "tokenizer_chunk_size": tokenizer.chunk_size,
    "tokenizer_word_chunk_size": tokenizer.word_chunk_size,
    "tokenizer_cyrillic_chunk_size": tokenizer.cyrillic_chunk_size,
    "tokenizer_version": tokenizer.tokenizer_version,
    "stride": STRIDE,
    "max_sequences": MAX_SEQUENCES,
    "answer_only_loss": ANSWER_ONLY_LOSS,
    "mixed_precision": tf_settings["mixed_precision"],
    "xla": tf_settings["xla"],
}}
save_artifacts(model, tokenizer, MODEL_DIR, config)
print(f"Saved model, tokenizer, and config to {{MODEL_DIR}}")
"""
    # MODEL_DIR is defined by the training cell above.
    zip_code = """
import shutil
from google.colab import files
archive_path = shutil.make_archive(MODEL_DIR, "zip", MODEL_DIR)
print(f"Created archive: {archive_path}")
files.download(archive_path)
"""

    intro_cell = make_notebook_cell(
        "markdown",
        f"# Train the tiny TG model on `{dataset_id}`\n\n"
        "Run this notebook in Google Colab with a GPU runtime. "
        "It downloads the Hugging Face dataset inside Colab, trains the model, "
        "and saves `model.keras`, `tokenizer.json`, and `config.json`.",
    )
    install_cell = make_notebook_cell("code", "%pip -q install datasets huggingface_hub")
    # repr() escapes the newlines of the source, so the write call stays on
    # a single line of the cell.
    write_app_cell = make_notebook_cell(
        "code",
        "from pathlib import Path\n"
        f"Path('app.py').write_text({app_source!r}, encoding='utf-8')\n"
        "print('Wrote app.py into the Colab runtime.')",
    )
    train_cell = make_notebook_cell("code", train_code.strip())
    download_note_cell = make_notebook_cell(
        "markdown",
        "Run the next cell after training if you want to download the saved artifacts as a zip.",
    )
    download_cell = make_notebook_cell("code", zip_code.strip())

    notebook_metadata = {
        "accelerator": "GPU",
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3",
        },
        "language_info": {"name": "python"},
        "colab": {"gpuType": "T4"},
    }
    return {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": notebook_metadata,
        "cells": [
            intro_cell,
            install_cell,
            write_app_cell,
            train_cell,
            download_note_cell,
            download_cell,
        ],
    }
def write_colab_notebook(output_path=DEFAULT_COLAB_NOTEBOOK, dataset_id=TG_DATASET_PRO_ID):
    """Build the Colab training notebook and write it to disk as JSON.

    Args:
        output_path: Destination file for the notebook. Missing parent
            directories are created so nested paths work.
        dataset_id: Dataset identifier forwarded to ``build_colab_notebook``.

    Returns:
        The ``Path`` of the written notebook file.
    """
    notebook = build_colab_notebook(dataset_id=dataset_id)
    path = Path(output_path)
    # A nested destination such as "out/train.ipynb" would otherwise fail with
    # FileNotFoundError when "out" does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False keeps non-ASCII characters readable in the file, which
    # is why the file is written explicitly as UTF-8.
    path.write_text(json.dumps(notebook, ensure_ascii=False, indent=2), encoding="utf-8")
    return path
def colab_command(args):
    """Create the Colab notebook requested by ``args`` and print next steps.

    Reads ``args.output`` (notebook destination) and ``args.dataset_id``.
    """
    created_notebook = write_colab_notebook(
        args.output,
        dataset_id=args.dataset_id,
    )
    print(f"Created Colab notebook: {created_notebook}")
    print("Upload it to Google Colab, choose a GPU runtime, then run all cells.")
def ask_text(prompt, default=None, required=False):
    """Prompt on stdin until an acceptable text answer is available.

    Args:
        prompt: Text shown before the input cursor.
        default: Value returned (as ``str``) when the user enters nothing;
            it is also displayed in square brackets after the prompt.
        required: When true and there is no default, keep asking until the
            user types something.

    Returns:
        The stripped answer, the stringified default, or ``""`` when the
        answer is empty, optional, and has no default.
    """
    # The bracketed hint depends only on the default, so build it once.
    hint = "" if default is None else f" [{default}]"
    while True:
        answer = input(f"{prompt}{hint}: ").strip()
        if answer:
            return answer
        if default is not None:
            return str(default)
        if not required:
            return ""
        print("Please enter a value.")
def ask_int(prompt, default):
    """Prompt until the answer (or the default) parses as an integer.

    Args:
        prompt: Text shown to the user.
        default: Value offered when the user enters nothing.

    Returns:
        The answer converted with ``int``.
    """
    while True:
        raw_answer = ask_text(prompt, default)
        try:
            number = int(raw_answer)
        except ValueError:
            print("Please enter a whole number.")
        else:
            return number
def ask_float(prompt, default):
    """Prompt until the answer (or the default) parses as a float.

    Args:
        prompt: Text shown to the user.
        default: Value offered when the user enters nothing.

    Returns:
        The answer converted with ``float``.
    """
    while True:
        raw_answer = ask_text(prompt, default)
        try:
            number = float(raw_answer)
        except ValueError:
            print("Please enter a number.")
        else:
            return number
def ask_bool(prompt, default=True):
    """Prompt for a yes/no answer until a recognised value is entered.

    Args:
        prompt: Text shown to the user.
        default: Answer used when the user enters nothing; shown as
            ``y`` or ``n`` in the prompt.

    Returns:
        ``True`` for y/yes/true/1 and ``False`` for n/no/false/0
        (case-insensitive).
    """
    affirmative = ("y", "yes", "true", "1")
    negative = ("n", "no", "false", "0")
    fallback = "y" if default else "n"
    while True:
        answer = ask_text(prompt, fallback).lower()
        if answer in affirmative:
            return True
        if answer in negative:
            return False
        print("Please enter y or n.")
class MenuArgs:
    """Plain attribute bag that the interactive menu fills in and hands to the command functions."""
def build_train_args_from_input():
    """Interactively collect all training options into a ``MenuArgs`` object.

    The dataset paths are asked first so ``estimate_dataset_defaults`` can
    suggest settings. If that estimate fails for any reason, a fixed set of
    generic defaults is used. Each remaining option is then prompted with its
    default pre-filled.

    Returns:
        A ``MenuArgs`` instance with ``data``, ``model_dir`` and every
        training hyperparameter set as attributes.
    """
    args = MenuArgs()
    raw_paths = ask_text("Dataset path(s), separated by commas", required=True)
    args.data = [piece for piece in map(str.strip, raw_paths.split(",")) if piece]

    try:
        defaults = estimate_dataset_defaults(args.data)
        print("\nDataset estimate:")
        print(f"Records: {defaults['records']:,}")
        print(f"Tokens: {defaults['tokens']:,}")
        print(f"Unique tokenizer tokens: {defaults['unique_tokens']:,}")
        print("Recommended fast settings for this dataset are pre-filled below.")
    except Exception as exc:
        # Best effort: an unreadable dataset must not block the prompts.
        print(f"Could not estimate dataset defaults, using generic fast defaults: {exc}")
        defaults = dict(
            epochs=3,
            batch_size=128,
            validation_split=0.02,
            maxlen=96,
            stride=48,
            max_vocab_size=32000,
            max_sequences=60000,
            embed_dim=96,
            num_heads=4,
            num_layers=2,
            ff_dim=192,
            dropout=0.15,
            answer_only_loss=True,
        )

    args.model_dir = ask_text("Model save directory", "tg-medium")

    # (attribute / defaults key, prompt helper, prompt label), in display order.
    questions = (
        ("epochs", ask_int, "Epochs"),
        ("batch_size", ask_int, "Batch size"),
        ("validation_split", ask_float, "Validation split"),
        ("maxlen", ask_int, "Context length / max sequence length"),
        ("stride", ask_int, "Training stride, higher is faster"),
        ("max_vocab_size", ask_int, "Max vocab size"),
        ("max_sequences", ask_int, "Max training sequences, 0 = all"),
        ("embed_dim", ask_int, "Embedding size"),
        ("num_heads", ask_int, "Attention heads"),
        ("num_layers", ask_int, "Transformer blocks"),
        ("ff_dim", ask_int, "Feed-forward size"),
        ("dropout", ask_float, "Dropout"),
        ("answer_only_loss", ask_bool, "Train loss only on answer tokens"),
    )
    for attribute, ask, label in questions:
        setattr(args, attribute, ask(label, defaults[attribute]))

    args.mixed_precision = ask_bool("Use mixed precision on GPU", True)
    args.xla = ask_bool("Use XLA/JIT compile", True)
    return args
def build_chat_args_from_input():
    """Prompt for the options used when loading a model for chat.

    Returns:
        A ``MenuArgs`` instance with ``model_dir``, ``maxlen``,
        ``max_new_tokens`` and ``temperature`` set.
    """
    chat_args = MenuArgs()
    chat_args.model_dir = ask_text("Model directory to load", "tg-medium")
    chat_args.maxlen = ask_int("Fallback context length if config is missing", 100)
    chat_args.max_new_tokens = ask_int("Max new tokens per reply", 80)
    chat_args.temperature = ask_float("Temperature, 0 for greedy", 0.8)
    return chat_args
def build_generate_args_from_input():
    """Prompt for the chat options plus a mandatory one-shot prompt.

    Returns:
        The object from ``build_chat_args_from_input`` with an extra
        ``prompt`` attribute.
    """
    generate_args = build_chat_args_from_input()
    generate_args.prompt = ask_text("Prompt", required=True)
    return generate_args
def build_save_args_from_input():
    """Prompt for the source model directory and the copy destination.

    Returns:
        A ``MenuArgs`` instance with ``model_dir`` and ``output_dir`` set;
        the output directory is mandatory.
    """
    save_args = MenuArgs()
    save_args.model_dir = ask_text("Model directory to load", "tg-medium")
    save_args.output_dir = ask_text("Output directory for saved copy", required=True)
    return save_args
def build_colab_args_from_input():
    """Prompt for the dataset ID and output path of the Colab notebook.

    Returns:
        A ``MenuArgs`` instance with ``dataset_id`` and ``output`` set,
        defaulting to ``TG_DATASET_PRO_ID`` and ``DEFAULT_COLAB_NOTEBOOK``.
    """
    colab_args = MenuArgs()
    colab_args.dataset_id = ask_text("Hugging Face dataset ID", TG_DATASET_PRO_ID)
    colab_args.output = ask_text("Output notebook path", DEFAULT_COLAB_NOTEBOOK)
    return colab_args
def print_menu():
    """Print the numbered list of CLI actions, preceded by a blank line."""
    menu_lines = (
        "\nSmall LM CLI",
        "1. Train model",
        "2. Chat with model",
        "3. Generate one response",
        "4. Save/copy loaded model artifacts",
        "5. Create Colab auto-train notebook",
        "6. Exit",
    )
    # One joined write yields the same stdout text as printing line by line.
    print("\n".join(menu_lines))
def main():
    """Run the interactive menu loop until the user chooses to exit.

    Each iteration prints the menu, reads a choice (a number or a keyword),
    gathers that command's options interactively and runs it. Errors raised
    by a command are reported and the menu is shown again. End-of-input or
    Ctrl-C at the menu prompt ends the loop cleanly.
    """
    while True:
        print_menu()
        try:
            choice = input("Choose an option: ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            # stdin was closed (e.g. piped input ran out) or the user pressed
            # Ctrl-C at the menu: leave quietly instead of dying with a
            # traceback. The bare print ends the unfinished prompt line.
            print()
            break
        try:
            if choice in {"1", "train"}:
                train_command(build_train_args_from_input())
            elif choice in {"2", "chat"}:
                chat_command(build_chat_args_from_input())
            elif choice in {"3", "generate"}:
                generate_command(build_generate_args_from_input())
            elif choice in {"4", "save", "copy"}:
                save_command(build_save_args_from_input())
            elif choice in {"5", "colab"}:
                colab_command(build_colab_args_from_input())
            elif choice in {"6", "exit", "quit", "q"}:
                break
            else:
                print("Please choose 1, 2, 3, 4, 5, or 6.")
        except Exception as exc:
            # Top-level boundary: report the failure and return to the menu
            # so one failed command does not end the session.
            print(f"Error: {exc}")
# Start the interactive menu only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()