tg_medium / app_internal.py
AILaborant's picture
Create app_internal.py
9f6c3b2 verified
import csv
import json
import re
from collections import Counter
from pathlib import Path
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Embedding, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Model
# Reserved vocabulary entries. fit() places SPECIAL_TOKENS, in this order, at
# the front of every vocabulary it builds.
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
START_TOKEN = "<start>"
END_TOKEN = "<end>"
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]
# First-pass splitter: a run of word characters, one of the listed
# multi-character operators, or any single non-word character.
TOKEN_PATTERN = re.compile(
    r"\w+|==|!=|<=|>=|->|=>|::|\+\+|--|&&|\|\||//|/\*|\*/|[^\w]",
    re.UNICODE,
)
# Version tag written by ThreeChunkTokenizer.to_dict(); from_dict() falls back
# to "fixed_chunk_legacy" when a saved tokenizer has no tag.
TOKENIZER_VERSION = "regex_subword_v1"
# Default chunk widths (in characters) for the legacy path, non-Cyrillic
# segments, and Cyrillic segments respectively.
DEFAULT_TOKEN_CHUNK_SIZE = 2
DEFAULT_WORD_CHUNK_SIZE = 4
DEFAULT_CYRILLIC_CHUNK_SIZE = 2
# Second-pass splitter applied to word pieces: Cyrillic runs, uppercase runs
# that end before a capitalized word / digit / underscore / end of piece,
# optionally-capitalized lowercase runs, digit runs, underscore runs, or a
# single non-word character.
IDENTIFIER_BOUNDARY_PATTERN = re.compile(
    r"[А-Яа-яЁё]+|[A-Z]+(?=[A-Z][a-z]|[0-9_]|$)|[A-Z]?[a-z]+|[0-9]+|_+|[^\w]",
    re.UNICODE,
)
# Single characters that fit() always adds to the vocabulary, whether or not
# they occur in the training texts.
COMMON_FALLBACK_TOKENS = list(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
    "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
    "_ \n\t.,!?;:()[]{}<>+-*/=%'\"`\\|&^~#$@"
)
# File suffixes picked up when read_training_texts() scans a directory.
SUPPORTED_DATA_EXTENSIONS = {".txt", ".md", ".json", ".jsonl", ".ndjson", ".csv"}
# Record keys probed by first_field(); within each tuple the earliest key with
# a non-empty value wins.
TEXT_FIELDS = ("text", "content", "completion", "response", "output", "answer")
QUESTION_FIELDS = ("question", "prompt", "instruction", "input", "query")
PROMPT_FIELDS = QUESTION_FIELDS
ANSWER_FIELDS = ("answer", "completion", "response", "output")
# Labels used to render question/answer pairs; make_loss_mask() searches for
# the encoded ANSWER_LABEL to find where the answer starts.
QUESTION_LABEL = "Question:"
ANSWER_LABEL = "Answer:"
# Defaults for the generated Colab notebook (dataset id and output filename).
TG_DATASET_PRO_ID = "AILaborant/tg_dataset_pro"
DEFAULT_COLAB_NOTEBOOK = "tg_dataset_pro_colab_train.ipynb"
class ThreeChunkTokenizer:
    """Regex subword tokenizer tuned for Russian prose and code-like text.

    Text is first split by ``TOKEN_PATTERN``. Word-like pieces are split again
    at identifier boundaries and cut into fixed-size chunks. Once a vocabulary
    has been fitted, tokens missing from it are re-segmented by greedy
    longest-match against that vocabulary.
    """

    def __init__(
        self,
        vocab=None,
        max_vocab_size=128000,
        chunk_size=DEFAULT_TOKEN_CHUNK_SIZE,
        word_chunk_size=DEFAULT_WORD_CHUNK_SIZE,
        cyrillic_chunk_size=DEFAULT_CYRILLIC_CHUNK_SIZE,
        tokenizer_version=TOKENIZER_VERSION,
    ):
        """Create a tokenizer, optionally restoring an existing vocabulary.

        Args:
            vocab: Iterable of tokens in id order; when empty or None the
                vocabulary starts with only the special tokens.
            max_vocab_size: Budget used by ``fit`` for learned tokens.
            chunk_size: Chunk width for the ``fixed_chunk_legacy`` path.
            word_chunk_size: Chunk width for non-Cyrillic segments.
            cyrillic_chunk_size: Chunk width for segments containing Cyrillic.
            tokenizer_version: ``"fixed_chunk_legacy"`` selects the legacy
                tokenization; any other value uses the regex subword path.
        """
        self.max_vocab_size = max_vocab_size
        # Chunk sizes are clamped to at least 1 so slicing always advances.
        self.chunk_size = max(1, int(chunk_size))
        self.word_chunk_size = max(1, int(word_chunk_size))
        self.cyrillic_chunk_size = max(1, int(cyrillic_chunk_size))
        self.tokenizer_version = tokenizer_version
        self.vocab = list(vocab) if vocab else list(SPECIAL_TOKENS)
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.max_token_length = max((len(token) for token in self.vocab), default=1)

    def _is_fitted(self):
        """Return True when the vocabulary holds more than the special tokens."""
        return len(self.vocab) > len(SPECIAL_TOKENS)

    def _is_cyrillic(self, text):
        """Return True when any character of ``text`` is in а-я (case-insensitive) or is Ё/ё."""
        return any("а" <= char.lower() <= "я" or char in "Ёё" for char in text)

    def _chunk_segment(self, segment):
        """Cut one boundary-delimited segment into tokens.

        Underscore runs become single characters, digit runs become groups of
        four, and everything else is cut by the Cyrillic or word chunk size.
        """
        if not segment:
            return []
        if segment.startswith("_"):
            return list(segment)
        if segment.isdigit():
            return [segment[i : i + 4] for i in range(0, len(segment), 4)]
        chunk_size = self.cyrillic_chunk_size if self._is_cyrillic(segment) else self.word_chunk_size
        return [segment[i : i + chunk_size] for i in range(0, len(segment), chunk_size)]

    def _tokenize_word_piece(self, piece):
        """Split a word-like piece at identifier boundaries and chunk each segment.

        Characters that ``IDENTIFIER_BOUNDARY_PATTERN`` cannot match (for
        example non-ASCII, non-Russian letters, or an uppercase Latin letter
        directly followed by Cyrillic) are emitted as single-character tokens
        rather than dropped, so no input text is lost.
        """
        tokens = []
        cursor = 0
        for match in IDENTIFIER_BOUNDARY_PATTERN.finditer(piece):
            if match.start() > cursor:
                # Gap the pattern skipped over: keep it, one character per token.
                tokens.extend(piece[cursor : match.start()])
            tokens.extend(self._chunk_segment(match.group()))
            cursor = match.end()
        tokens.extend(piece[cursor:])
        return tokens

    def _tokenize_raw(self, text):
        """Tokenize ``text`` without consulting the vocabulary."""
        tokens = []
        for piece in TOKEN_PATTERN.findall(text):
            if piece.isalnum() or "_" in piece:
                tokens.extend(self._tokenize_word_piece(piece))
            else:
                tokens.append(piece)
        return tokens

    def _tokenize_fixed_chunk_legacy(self, text):
        """Legacy tokenization: cut every word piece into ``chunk_size`` slices."""
        tokens = []
        for piece in TOKEN_PATTERN.findall(text):
            if piece.isalnum() or "_" in piece:
                tokens.extend(piece[i : i + self.chunk_size] for i in range(0, len(piece), self.chunk_size))
            else:
                tokens.append(piece)
        return tokens

    def _tokenize_with_vocab(self, token):
        """Re-segment ``token`` by greedy longest-match against the vocabulary.

        Returns ``[token]`` when it is already known. Otherwise it scans left
        to right, taking the longest known substring at each position and
        emitting ``UNK_TOKEN`` for a character with no known match.
        """
        if token in self.token_to_id:
            return [token]
        pieces = []
        index = 0
        while index < len(token):
            match = None
            # No vocabulary entry is longer than max_token_length.
            max_end = min(len(token), index + self.max_token_length)
            for end in range(max_end, index, -1):
                candidate = token[index:end]
                if candidate in self.token_to_id:
                    match = candidate
                    break
            if match is None:
                pieces.append(UNK_TOKEN)
                index += 1
            else:
                pieces.append(match)
                index += len(match)
        return pieces

    def tokenize(self, text):
        """Return the token strings for ``text`` according to the tokenizer version."""
        if self.tokenizer_version == "fixed_chunk_legacy":
            return self._tokenize_fixed_chunk_legacy(text)
        tokens = self._tokenize_raw(text)
        if not self._is_fitted():
            return tokens
        segmented = []
        for token in tokens:
            segmented.extend(self._tokenize_with_vocab(token))
        return segmented

    def detokenize(self, tokens):
        """Concatenate ``tokens``, skipping every special token."""
        return "".join(token for token in tokens if token not in SPECIAL_TOKENS)

    def fit(self, texts):
        """Build the vocabulary from ``texts`` and return ``self``.

        The vocabulary is the special tokens, then the required single
        characters (``COMMON_FALLBACK_TOKENS`` plus every character seen in the
        texts), then the most frequent raw tokens until ``max_vocab_size`` is
        reached. Required characters are always kept, even when they alone
        exceed the budget; in that case no learned tokens are added.
        """
        counts = Counter()
        character_counts = Counter()
        for text in texts:
            counts.update(self._tokenize_raw(text))
            character_counts.update(text)
        required_tokens = []
        seen = set(SPECIAL_TOKENS)
        for token in COMMON_FALLBACK_TOKENS + [token for token, _ in character_counts.most_common()]:
            if token not in seen:
                required_tokens.append(token)
                seen.add(token)
        remaining = max(0, self.max_vocab_size - len(SPECIAL_TOKENS) - len(required_tokens))
        learned_tokens = []
        for token, _ in counts.most_common():
            # Check the budget before appending so a zero budget adds nothing.
            if len(learned_tokens) >= remaining:
                break
            if token in seen:
                continue
            learned_tokens.append(token)
            seen.add(token)
        self.vocab = list(SPECIAL_TOKENS) + required_tokens + learned_tokens
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.max_token_length = max((len(token) for token in self.vocab), default=1)
        return self

    def encode(self, text, add_boundaries=True, maxlen=None):
        """Convert ``text`` to a list of token ids.

        Args:
            text: Input string.
            add_boundaries: Wrap the tokens in ``START_TOKEN``/``END_TOKEN``.
            maxlen: When given, truncate to this length and right-pad with
                the pad id up to it.
        """
        tokens = self.tokenize(text)
        if add_boundaries:
            tokens = [START_TOKEN] + tokens + [END_TOKEN]
        ids = [self.token_to_id.get(token, self.token_to_id[UNK_TOKEN]) for token in tokens]
        if maxlen is not None:
            ids = ids[:maxlen]
            ids.extend([self.token_to_id[PAD_TOKEN]] * (maxlen - len(ids)))
        return ids

    def decode_ids(self, ids):
        """Convert token ids back to text, ignoring ids outside the vocabulary.

        Negative ids are ignored as well; they would otherwise index the
        vocabulary from its end.
        """
        vocab_size = len(self.vocab)
        tokens = [self.vocab[idx] for idx in map(int, ids) if 0 <= idx < vocab_size]
        return self.detokenize(tokens)

    def to_dict(self):
        """Return a JSON-serializable snapshot of the tokenizer."""
        return {
            "vocab": self.vocab,
            "max_vocab_size": self.max_vocab_size,
            "chunk_size": self.chunk_size,
            "word_chunk_size": self.word_chunk_size,
            "cyrillic_chunk_size": self.cyrillic_chunk_size,
            "tokenizer_version": self.tokenizer_version,
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a tokenizer from ``to_dict`` output, including older artifacts."""
        # Older saved artifacts did not store this value and were trained with
        # 3-character chunks, so keep those models' prompt tokenization stable.
        return cls(
            vocab=data["vocab"],
            max_vocab_size=data.get("max_vocab_size", 128000),
            chunk_size=data.get("chunk_size", 3),
            word_chunk_size=data.get("word_chunk_size", data.get("chunk_size", 3)),
            cyrillic_chunk_size=data.get("cyrillic_chunk_size", data.get("chunk_size", 3)),
            tokenizer_version=data.get("tokenizer_version", "fixed_chunk_legacy"),
        )
@tf.keras.utils.register_keras_serializable()
class TransformerBlock(tf.keras.layers.Layer):
    """Self-attention block with a causal mask.

    Attention and the feed-forward network are each followed by dropout, a
    residual add, and layer normalization.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        """Create the sub-layers.

        Args:
            embed_dim: Output width of the feed-forward network; also used as
                the attention ``key_dim``.
            num_heads: Number of attention heads.
            ff_dim: Hidden width of the feed-forward network.
            rate: Dropout rate applied after attention and after the
                feed-forward network.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can serialize them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply masked self-attention and the feed-forward network to ``inputs``."""
        seq_len = tf.shape(inputs)[1]
        # Lower-triangular boolean mask: each position attends only to itself
        # and earlier positions.
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), dtype=tf.bool), -1, 0)
        attn_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attn_output = self.dropout1(attn_output, training=training)
        # Cast so the residual add uses matching dtypes
        # (presumably needed under mixed precision -- TODO confirm).
        attn_output = tf.cast(attn_output, inputs.dtype)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        ffn_output = tf.cast(ffn_output, out1.dtype)
        return self.layernorm2(out1 + ffn_output)

    def build(self, input_shape):
        """Build the attention, feed-forward, and normalization sub-layers for ``input_shape``."""
        self.att.build(input_shape, input_shape)
        self.ffn.build(input_shape)
        self.layernorm1.build(input_shape)
        self.layernorm2.build(input_shape)
        super().build(input_shape)

    def get_config(self):
        """Return the base layer config extended with this block's constructor arguments."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config
@tf.keras.utils.register_keras_serializable()
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        """Create the two embedding tables.

        Args:
            maxlen: Number of rows in the position embedding table.
            vocab_size: Number of rows in the token embedding table.
            embed_dim: Width of both embeddings.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can serialize them.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Embed token ids ``x`` and add position embeddings for its last axis."""
        # Positions are derived from the runtime length of the last axis, not
        # from self.maxlen.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        # Match dtypes before the add.
        positions = tf.cast(positions, x.dtype)
        return x + positions

    def build(self, input_shape):
        """Build both embedding tables."""
        self.token_emb.build(input_shape)
        self.pos_emb.build((self.maxlen,))
        super().build(input_shape)

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update({
            "maxlen": self.maxlen,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        })
        return config
def configure_tensorflow(use_mixed_precision=True, use_xla=True):
    """Apply best-effort TensorFlow runtime settings and report what took effect.

    Returns a dict with ``gpu_count`` plus booleans ``memory_growth``,
    ``mixed_precision`` and ``xla`` telling which settings were applied.
    """
    detected_gpus = tf.config.list_physical_devices("GPU")
    status = {
        "gpu_count": len(detected_gpus),
        "memory_growth": False,
        "mixed_precision": False,
        "xla": False,
    }

    for device in detected_gpus:
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except (RuntimeError, ValueError):
            # Colab may initialize the GPU before training starts. In that case
            # memory growth cannot be changed, but training can continue safely.
            continue
        status["memory_growth"] = True

    if use_mixed_precision and detected_gpus:
        try:
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
        except ValueError:
            pass
        else:
            status["mixed_precision"] = True

    if use_xla:
        try:
            tf.config.optimizer.set_jit(True)
        except ValueError:
            pass
        else:
            status["xla"] = True

    return status
def build_model(
    vocab_size,
    maxlen=100,
    embed_dim=96,
    num_heads=4,
    ff_dim=192,
    dropout=0.2,
    num_layers=1,
    jit_compile=True,
):
    """Assemble and compile the next-token model.

    The network is an embedding layer, at least one ``TransformerBlock``, and
    a float32 softmax over the vocabulary at every position.
    """
    token_ids = tf.keras.Input(shape=(maxlen,), dtype=tf.int32)
    hidden = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(token_ids)
    # At least one block is always stacked, whatever num_layers says.
    for _ in range(max(1, int(num_layers))):
        hidden = TransformerBlock(embed_dim, num_heads, ff_dim, dropout)(hidden)
    token_probabilities = Dense(vocab_size, activation="softmax", dtype="float32")(hidden)

    model = Model(inputs=token_ids, outputs=token_probabilities)
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name="answer_accuracy")
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        weighted_metrics=[accuracy_metric],
        jit_compile=jit_compile,
    )
    return model
def stringify_value(value):
    """Render an arbitrary JSON-like value as stripped text.

    ``None`` becomes ``""``, lists are joined with newlines after dropping
    empty items, dicts are rendered by ``record_to_text``, and anything else
    goes through ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, (int, float, bool)):
        return str(value)
    if isinstance(value, list):
        rendered_items = (stringify_value(element) for element in value)
        return "\n".join(text for text in rendered_items if text)
    if isinstance(value, dict):
        return record_to_text(value)
    return str(value).strip()
def messages_to_text(messages):
    """Flatten a chat-style message list into newline-separated text.

    Dict messages are rendered as ``"role: content"`` (or just the content
    when no role is present); other items are stringified as-is. Entries with
    no content are skipped.
    """
    rendered = []
    for entry in messages:
        if isinstance(entry, dict):
            role = stringify_value(entry.get("role") or entry.get("from") or entry.get("speaker"))
            content = stringify_value(entry.get("content") or entry.get("value") or entry.get("text"))
            if not content:
                continue
            rendered.append(f"{role}: {content}" if role else content)
        else:
            plain = stringify_value(entry)
            if plain:
                rendered.append(plain)
    return "\n".join(rendered).strip()
def first_field(record, field_names):
    """Return the first non-empty stringified value among ``field_names``, or ``""``."""
    # The generator is lazy, so lookup stops at the first non-empty field.
    rendered = (stringify_value(record.get(name)) for name in field_names)
    return next((text for text in rendered if text), "")
def format_question_answer(question, answer):
    """Render a labeled question line followed by a labeled answer line."""
    question_line = f"{QUESTION_LABEL} {question}"
    answer_line = f"{ANSWER_LABEL} {answer}"
    return "\n".join((question_line, answer_line))
def format_question_prompt(question):
    """Render a labeled question line followed by a bare answer label."""
    question_line = f"{QUESTION_LABEL} {question}"
    return "\n".join((question_line, f"{ANSWER_LABEL}"))
def record_to_text(record):
    """Convert one dataset record into a training text.

    Resolution order: question/answer pair, ``messages`` list, a direct text
    field, and finally a ``key: value`` dump of every non-empty field.
    Non-dict records are simply stringified.
    """
    if not isinstance(record, dict):
        return stringify_value(record)

    question = first_field(record, QUESTION_FIELDS)
    answer = first_field(record, ANSWER_FIELDS)
    if question and answer:
        return format_question_answer(question, answer)

    messages = record.get("messages")
    if isinstance(messages, list):
        conversation = messages_to_text(messages)
        if conversation:
            return conversation

    direct_text = first_field(record, TEXT_FIELDS)
    if direct_text:
        return direct_text

    rendered = ((key, stringify_value(value)) for key, value in record.items())
    return "\n".join(f"{key}: {text}" for key, text in rendered if text).strip()
def walk_json_records(data):
    """Yield individual records from a parsed JSON document.

    Lists are flattened recursively. A dict that holds a list under one of the
    known container keys is replaced by the records in those lists; any other
    value is yielded unchanged.
    """
    if isinstance(data, list):
        for element in data:
            yield from walk_json_records(element)
        return
    if not isinstance(data, dict):
        yield data
        return

    container_keys = ("train", "validation", "test", "data", "examples", "rows", "records")
    nested_lists = [data[key] for key in container_keys if isinstance(data.get(key), list)]
    if not nested_lists:
        # No container key holds a list, so the dict itself is the record.
        yield data
        return
    for nested in nested_lists:
        yield from walk_json_records(nested)
def read_json_file(path):
    """Read a ``.json`` dataset file and return one text per record.

    Raises:
        json.JSONDecodeError: When the file is neither valid JSON nor a
            sequence of JSON documents.
        ValueError: When the JSONL fallback hits an invalid line.
    """
    contents = path.read_text(encoding="utf-8-sig")
    try:
        document = json.loads(contents)
    except json.JSONDecodeError as exc:
        # Some datasets are JSONL but are still named .json. If parsing a full
        # JSON document fails, try reading one JSON object per line.
        if "Extra data" in exc.msg:
            return read_jsonl_text(contents, path)
        raise
    return [record_to_text(entry) for entry in walk_json_records(document)]
def read_jsonl_text(raw_text, path):
    """Parse JSON Lines content and return one text per non-blank line.

    Args:
        raw_text: Full file content.
        path: Source path, used only in the error message.

    Raises:
        ValueError: When a non-blank line is not valid JSON.
    """
    records = []
    for line_number, raw_line in enumerate(raw_text.splitlines(), start=1):
        payload = raw_line.strip()
        if payload:
            try:
                parsed = json.loads(payload)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON on line {line_number} in {path}") from exc
            records.append(record_to_text(parsed))
    return records
def read_jsonl_file(path):
    """Read a JSON Lines file (BOM tolerated) and return one text per record."""
    contents = path.read_text(encoding="utf-8-sig")
    return read_jsonl_text(contents, path)
def read_csv_file(path):
    """Read a CSV file with a header row and return one text per data row.

    The file is decoded as ``utf-8-sig``, matching the JSON readers. Without
    that, a leading byte-order mark becomes part of the first column name, so
    the field-name lookups in ``record_to_text`` would miss that column.
    """
    with path.open("r", encoding="utf-8-sig", newline="") as file:
        return [record_to_text(row) for row in csv.DictReader(file)]
def read_dataset_file(path):
    """Read one dataset file and return its training texts.

    Plain-text and Markdown files yield a single text with the whole file
    content. JSON, JSONL/NDJSON and CSV files are handed to their readers.
    Unsupported suffixes yield an empty list.
    """
    extension = path.suffix.lower()
    if extension in {".txt", ".md"}:
        # utf-8-sig matches the JSON readers and keeps a leading byte-order
        # mark out of the training text.
        return [path.read_text(encoding="utf-8-sig")]
    if extension == ".json":
        return read_json_file(path)
    if extension in {".jsonl", ".ndjson"}:
        return read_jsonl_file(path)
    if extension == ".csv":
        return read_csv_file(path)
    return []
def read_training_texts(paths):
    """Collect non-blank training texts from files and directories.

    Directories are searched recursively, in sorted order, for supported
    suffixes. Other paths are read directly.

    Raises:
        ValueError: When no non-blank text was found.
    """
    collected = []
    for raw_path in paths:
        location = Path(raw_path)
        if not location.is_dir():
            collected.extend(read_dataset_file(location))
            continue
        for candidate in sorted(location.rglob("*")):
            if candidate.is_file() and candidate.suffix.lower() in SUPPORTED_DATA_EXTENSIONS:
                collected.extend(read_dataset_file(candidate))

    usable = [text for text in collected if text and text.strip()]
    if usable:
        return usable
    supported = ", ".join(sorted(SUPPORTED_DATA_EXTENSIONS))
    raise ValueError(f"No training text found. Supported dataset files: {supported}")
def estimate_dataset_defaults(paths):
    """Scan the dataset once and suggest training hyperparameters.

    Texts are tokenized with an unfitted ``ThreeChunkTokenizer``. Every text
    counts two extra tokens for the start/end boundaries.

    Returns:
        A dict with dataset statistics (``records``, ``tokens``,
        ``unique_tokens``) and the recommended settings used to pre-fill the
        training prompts.

    Raises:
        ValueError: Propagated from ``read_training_texts`` when no text is
            found.
    """
    texts = read_training_texts(paths)
    # One unfitted tokenizer serves every text; it is not rebuilt per record.
    tokenizer = ThreeChunkTokenizer()
    counter = Counter()
    token_count = 0
    sequence_lengths = []
    for text in texts:
        tokens = tokenizer.tokenize(text)
        counter.update(tokens)
        sequence_length = len(tokens) + 2
        sequence_lengths.append(sequence_length)
        token_count += sequence_length
    unique_tokens = len(counter)
    length_p95 = int(np.percentile(sequence_lengths, 95)) if sequence_lengths else 96
    # Round the 95th-percentile length up to a multiple of 32, within [96, 256].
    recommended_maxlen = min(256, max(96, ((length_p95 + 31) // 32) * 32))
    recommended_vocab = min(32000, max(2048, unique_tokens + len(SPECIAL_TOKENS) + 512))
    # Large corpora get a wider stride and a cap on the number of sequences.
    is_large = token_count >= 1_000_000
    recommended_stride = 48 if is_large else 32
    recommended_max_sequences = 60000 if is_large else 0
    return {
        "records": len(texts),
        "tokens": token_count,
        "unique_tokens": unique_tokens,
        "epochs": 3,
        "batch_size": 128,
        "validation_split": 0.02,
        "maxlen": recommended_maxlen,
        "stride": recommended_stride,
        "max_vocab_size": recommended_vocab,
        "max_sequences": recommended_max_sequences,
        "embed_dim": 96,
        "num_heads": 4,
        "num_layers": 2,
        "ff_dim": 192,
        "dropout": 0.15,
        "answer_only_loss": True,
    }
def find_subsequence(values, pattern):
    """Return the index of the first occurrence of ``pattern`` in ``values``, or -1.

    An empty pattern, or one longer than ``values``, is never found.
    """
    if not pattern:
        return -1
    width = len(pattern)
    last_offset = len(values) - width
    if last_offset < 0:
        return -1
    hits = (
        offset
        for offset in range(last_offset + 1)
        if values[offset : offset + width] == pattern
    )
    return next(hits, -1)
def make_loss_mask(ids, tokenizer, answer_only_loss=True):
    """Build a float32 per-token loss mask for ``ids``.

    With ``answer_only_loss`` the mask is 0 up to and including the first
    encoded ``ANSWER_LABEL`` and 1 afterwards. When the label is absent, or
    answer-only loss is off, every position is 1.
    """
    length = len(ids)
    if answer_only_loss:
        label_ids = tokenizer.encode(ANSWER_LABEL, add_boundaries=False)
        label_start = find_subsequence(ids, label_ids)
        if label_start >= 0:
            mask = np.zeros(length, dtype="float32")
            mask[label_start + len(label_ids) :] = 1.0
            return mask
    return np.ones(length, dtype="float32")
def make_training_arrays(texts, tokenizer, maxlen, stride=48, max_sequences=0, answer_only_loss=True):
    """Slice encoded texts into fixed-length next-token training windows.

    Each window holds ``maxlen + 1`` ids; the first ``maxlen`` are the input
    and the last ``maxlen`` the shifted targets.

    Returns:
        ``(inputs, targets, sample_weights)`` arrays. ``sample_weights`` holds
        the loss mask aligned with the targets.
    """
    step = max(1, int(stride))
    sequence_cap = max(0, int(max_sequences))
    pad_id = tokenizer.token_to_id[PAD_TOKEN]
    window = maxlen + 1

    def iter_windows():
        # Lazily yield (ids_window, target_weights) pairs, text by text.
        for text in texts:
            ids = tokenizer.encode(text, add_boundaries=True)
            loss_mask = make_loss_mask(ids, tokenizer, answer_only_loss=answer_only_loss)
            shortfall = window - len(ids)
            if shortfall > 0:
                ids.extend([pad_id] * shortfall)
                loss_mask = np.pad(loss_mask, (0, shortfall), constant_values=0.0)
            final_start = max(0, len(ids) - window)
            offsets = list(range(0, final_start + 1, step))
            if offsets[-1] != final_start:
                # Always include the window that ends at the last token.
                offsets.append(final_start)
            for offset in offsets:
                target_weights = loss_mask[offset + 1 : offset + window]
                if answer_only_loss and not np.any(target_weights):
                    continue
                yield ids[offset : offset + window], target_weights

    sequences = []
    weights = []
    for id_window, target_weights in iter_windows():
        sequences.append(id_window)
        weights.append(target_weights)
        if sequence_cap and len(sequences) >= sequence_cap:
            break

    if not sequences:
        # Nothing usable was produced: fall back to one all-padding window.
        sequences.append([pad_id] * window)
        weights.append(np.ones(maxlen, dtype="float32"))

    data = np.asarray(sequences, dtype="int32")
    sample_weights = np.asarray(weights, dtype="float32")
    return data[:, :-1], data[:, 1:], sample_weights
def make_tf_datasets(x, y, sample_weights, batch_size, validation_split=0.0, shuffle_buffer=10000):
    """Shuffle, split and batch the training arrays into ``tf.data`` datasets.

    The arrays are permuted with a fixed seed (42) and the trailing
    ``validation_split`` fraction is held out.

    Returns:
        ``(train_ds, val_ds, train_count, val_count)``. ``val_ds`` is ``None``
        and ``val_count`` is 0 when nothing is held out.
    """
    order = np.random.default_rng(42).permutation(len(x))
    x, y, sample_weights = x[order], y[order], sample_weights[order]

    holdout = int(len(x) * validation_split)
    if holdout == 0 and validation_split > 0 and len(x) > 1:
        # A requested split always gets at least one validation example.
        holdout = 1

    if holdout:
        train_parts = (x[:-holdout], y[:-holdout], sample_weights[:-holdout])
        val_parts = (x[-holdout:], y[-holdout:], sample_weights[-holdout:])
    else:
        train_parts = (x, y, sample_weights)
        val_parts = None

    train_size = len(train_parts[0])
    train_ds = tf.data.Dataset.from_tensor_slices(train_parts)
    train_ds = train_ds.shuffle(min(train_size, shuffle_buffer), reshuffle_each_iteration=True)
    train_ds = train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    if val_parts is None:
        return train_ds, None, train_size, 0

    val_ds = tf.data.Dataset.from_tensor_slices(val_parts)
    val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds, train_size, len(val_parts[0])
def model_paths(model_dir):
    """Return the ``(model, tokenizer, config)`` file paths inside ``model_dir``."""
    root = Path(model_dir)
    filenames = ("model.keras", "tokenizer.json", "config.json")
    return tuple(root / filename for filename in filenames)
def save_artifacts(model, tokenizer, model_dir, config):
    """Write the model, tokenizer and config files into ``model_dir``.

    The directory is created if it does not exist.
    """
    target_dir = Path(model_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    model_file, tokenizer_file, config_file = model_paths(target_dir)

    model.save(model_file)
    json_outputs = (
        (tokenizer_file, tokenizer.to_dict()),
        (config_file, config),
    )
    for destination, payload in json_outputs:
        destination.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def load_artifacts(model_dir="tg-medium"):
    """Load the saved model, tokenizer and config from ``model_dir``.

    Returns:
        ``(model, tokenizer, config)``. ``config`` is an empty dict when the
        config file is absent.

    Raises:
        FileNotFoundError: When the model or tokenizer file is missing.
    """
    model_file, tokenizer_file, config_file = model_paths(model_dir)
    if not model_file.exists():
        raise FileNotFoundError(f"Model file not found: {model_file}")
    if not tokenizer_file.exists():
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_file}")

    custom_layers = {
        "TokenAndPositionEmbedding": TokenAndPositionEmbedding,
        "TransformerBlock": TransformerBlock,
    }
    model = tf.keras.models.load_model(model_file, custom_objects=custom_layers, compile=False)

    tokenizer_state = json.loads(tokenizer_file.read_text(encoding="utf-8"))
    tokenizer = ThreeChunkTokenizer.from_dict(tokenizer_state)

    if config_file.exists():
        config = json.loads(config_file.read_text(encoding="utf-8"))
    else:
        config = {}
    return model, tokenizer, config
def clean_answer_text(text):
    """Trim generated text down to the answer itself.

    Everything from the first question marker (``QUESTION_LABEL`` or
    ``"Prompt:"``) onwards is cut, and a leading ``ANSWER_LABEL`` is removed.
    """
    cleaned = text.strip()
    for marker in (QUESTION_LABEL, "Prompt:"):
        head, found, _ = cleaned.partition(marker)
        if found:
            cleaned = head.strip()
    if cleaned.startswith(ANSWER_LABEL):
        cleaned = cleaned[len(ANSWER_LABEL) :].strip()
    return cleaned.strip()
def generate_text(model, tokenizer, prompt, maxlen=100, max_new_tokens=80, temperature=1.0, qa_mode=True):
    """Generate a continuation for ``prompt`` one token at a time.

    Args:
        model: Model whose ``predict`` returns per-position token probabilities.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: User text; wrapped as a question prompt when ``qa_mode`` is set.
        maxlen: Model context length; only the last ``maxlen`` ids are fed in.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        qa_mode: Format the prompt as a question and clean the answer text.

    Returns:
        The decoded generated text (cleaned when ``qa_mode`` is set).
    """
    if qa_mode:
        prompt = format_question_prompt(prompt)
    context = [tokenizer.token_to_id[START_TOKEN]]
    context.extend(tokenizer.encode(prompt, add_boundaries=False))
    end_id = tokenizer.token_to_id[END_TOKEN]
    pad_id = tokenizer.token_to_id[PAD_TOKEN]
    generated = []
    for _ in range(max_new_tokens):
        input_ids = context[-maxlen:]
        # The prediction for the next token is read at the last real position;
        # the positions after it are padding.
        last_position = len(input_ids) - 1
        padded = input_ids + [pad_id] * (maxlen - len(input_ids))
        predictions = model.predict(np.asarray([padded], dtype="int32"), verbose=0)[0, last_position]
        if temperature <= 0:
            next_id = int(np.argmax(predictions))
        else:
            # Temperature-scaled softmax in float64. Subtracting the maximum
            # keeps at least one exp() equal to 1, so a small temperature can
            # no longer underflow every term to 0 and produce NaN probabilities.
            scaled = np.log(np.maximum(np.asarray(predictions, dtype="float64"), 1e-9)) / temperature
            scaled -= scaled.max()
            weights = np.exp(scaled)
            probabilities = weights / weights.sum()
            next_id = int(np.random.choice(len(probabilities), p=probabilities))
        if next_id == end_id:
            break
        context.append(next_id)
        generated.append(next_id)
    output = tokenizer.decode_ids(generated)
    return clean_answer_text(output) if qa_mode else output
def train_command(args):
    """Train a model from ``args`` and save the model, tokenizer and config.

    ``args`` carries the attributes set by ``build_train_args_from_input``.
    """
    tf_settings = configure_tensorflow(
        use_mixed_precision=args.mixed_precision,
        use_xla=args.xla,
    )

    corpus = read_training_texts(args.data)
    tokenizer = ThreeChunkTokenizer(max_vocab_size=args.max_vocab_size)
    tokenizer.fit(corpus)

    inputs, targets, sample_weights = make_training_arrays(
        corpus,
        tokenizer,
        args.maxlen,
        stride=args.stride,
        max_sequences=args.max_sequences,
        answer_only_loss=args.answer_only_loss,
    )
    train_ds, val_ds, train_count, val_count = make_tf_datasets(
        inputs,
        targets,
        sample_weights,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
    )

    # Architecture settings are shared by the model builder and the saved
    # config; the key order here fixes the order in config.json.
    architecture = {
        "maxlen": args.maxlen,
        "embed_dim": args.embed_dim,
        "num_heads": args.num_heads,
        "num_layers": args.num_layers,
        "ff_dim": args.ff_dim,
        "dropout": args.dropout,
    }
    model = build_model(vocab_size=len(tokenizer.vocab), jit_compile=args.xla, **architecture)

    print(
        f"Training on {train_count:,} sequences"
        f"{f' with {val_count:,} validation sequences' if val_count else ''}."
    )
    print(
        f"Vocab: {len(tokenizer.vocab):,} | Batch: {args.batch_size} | "
        f"Stride: {args.stride} | Blocks: {args.num_layers} | GPUs: {tf_settings['gpu_count']} | "
        f"Answer-only loss: {args.answer_only_loss} | "
        f"Memory growth: {tf_settings['memory_growth']} | "
        f"Mixed precision: {tf_settings['mixed_precision']} | XLA: {tf_settings['xla']}"
    )

    config = {
        **architecture,
        "vocab_size": len(tokenizer.vocab),
        "tokenizer_chunk_size": tokenizer.chunk_size,
        "tokenizer_word_chunk_size": tokenizer.word_chunk_size,
        "tokenizer_cyrillic_chunk_size": tokenizer.cyrillic_chunk_size,
        "tokenizer_version": tokenizer.tokenizer_version,
        "stride": args.stride,
        "max_sequences": args.max_sequences,
        "answer_only_loss": args.answer_only_loss,
        "mixed_precision": tf_settings["mixed_precision"],
        "xla": tf_settings["xla"],
    }

    model.fit(train_ds, epochs=args.epochs, validation_data=val_ds)
    save_artifacts(model, tokenizer, args.model_dir, config)
    print(f"Saved model, tokenizer, and config to {args.model_dir}")
def chat_command(args):
    """Run an interactive chat loop against the model saved in ``args.model_dir``.

    The loop ends on ``/exit``, ``/quit``, end of input, or Ctrl-C.
    """
    model, tokenizer, config = load_artifacts(args.model_dir)
    # The saved config wins over the fallback supplied by the user.
    context_length = int(config.get("maxlen", args.maxlen))
    stop_commands = {"/exit", "/quit"}

    print("Chat ready. Type /exit or /quit to stop.")
    while True:
        try:
            prompt = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            return
        if prompt.lower() in stop_commands:
            return
        response = generate_text(
            model,
            tokenizer,
            prompt,
            maxlen=context_length,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
        print(f"bot> {response}")
def generate_command(args):
    """Load the saved model and print one generated response for ``args.prompt``."""
    model, tokenizer, config = load_artifacts(args.model_dir)
    # The saved config wins over the fallback supplied by the user.
    context_length = int(config.get("maxlen", args.maxlen))
    response = generate_text(
        model,
        tokenizer,
        args.prompt,
        maxlen=context_length,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
    )
    print(response)
def save_command(args):
    """Copy saved artifacts by loading ``args.model_dir`` and re-saving to ``args.output_dir``."""
    loaded_model, loaded_tokenizer, loaded_config = load_artifacts(args.model_dir)
    save_artifacts(loaded_model, loaded_tokenizer, args.output_dir, loaded_config)
    print(f"Saved copy to {args.output_dir}")
def make_notebook_cell(cell_type, source):
    """Build one notebook cell dict in the nbformat v4 layout.

    Args:
        cell_type: Cell kind, e.g. ``"code"`` or ``"markdown"``.
        source: Cell text; stored as a list of newline-terminated lines.

    Returns:
        A dict with ``cell_type``, ``metadata`` and ``source``. Code cells
        also carry ``execution_count`` (``None``) and ``outputs`` (empty),
        which the nbformat v4 schema requires for that cell type.
    """
    cell = {
        "cell_type": cell_type,
        "metadata": {},
        "source": [line + "\n" for line in source.splitlines()],
    }
    if cell_type == "code":
        cell["execution_count"] = None
        cell["outputs"] = []
    return cell
def build_colab_notebook(dataset_id=TG_DATASET_PRO_ID):
    """Build a Colab training notebook as an nbformat-style dict.

    The notebook has six cells: a markdown intro, a pip install, a cell that
    writes this module's source to ``app.py``, the training cell, a markdown
    note, and a cell that zips and downloads the saved artifacts.

    Args:
        dataset_id: Hugging Face dataset id substituted into the training cell
            and the intro text.

    Returns:
        A dict ready to be serialized with ``json.dumps``.
    """
    # The running file's own source is embedded so the notebook can recreate it
    # as app.py; the training cell below imports from that module.
    app_source = Path(__file__).read_text(encoding="utf-8")
    # f-string template: only {dataset_id!r} is substituted here. Doubled braces
    # are escapes that become single braces in the generated cell.
    # NOTE(review): the template text carries no indentation in this view --
    # confirm that the generated cell is valid Python.
    train_code = f"""
DATASET_ID = {dataset_id!r}
SPLIT = "train"
MODEL_DIR = "/content/tg-dataset-pro-model"
# 0 means use the full dataset. Set a smaller number for a quick smoke test.
MAX_ROWS = 0
# These defaults are intentionally Colab-friendly. Increase MAX_SEQUENCES,
# EMBED_DIM, NUM_LAYERS, or EPOCHS if you have a stronger GPU/runtime.
EPOCHS = 3
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.02
MAXLEN = 128
STRIDE = 64
MAX_VOCAB_SIZE = 32000
MAX_SEQUENCES = 120000
EMBED_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 2
FF_DIM = 256
DROPOUT = 0.15
ANSWER_ONLY_LOSS = True
from datasets import load_dataset
from app import (
ThreeChunkTokenizer,
build_model,
configure_tensorflow,
make_tf_datasets,
make_training_arrays,
record_to_text,
save_artifacts,
)
tf_settings = configure_tensorflow(use_mixed_precision=True, use_xla=True)
print("TensorFlow settings:", tf_settings)
raw_dataset = load_dataset(DATASET_ID, split=SPLIT)
texts = []
for index, record in enumerate(raw_dataset):
if MAX_ROWS and index >= MAX_ROWS:
break
text = record_to_text(record)
if text and text.strip():
texts.append(text)
if not texts:
raise ValueError("No text rows were found in the Hugging Face dataset.")
print(f"Loaded {{len(texts):,}} training rows from {{DATASET_ID}}.")
tokenizer = ThreeChunkTokenizer(max_vocab_size=MAX_VOCAB_SIZE).fit(texts)
x, y, sample_weights = make_training_arrays(
texts,
tokenizer,
MAXLEN,
stride=STRIDE,
max_sequences=MAX_SEQUENCES,
answer_only_loss=ANSWER_ONLY_LOSS,
)
train_ds, val_ds, train_count, val_count = make_tf_datasets(
x,
y,
sample_weights,
batch_size=BATCH_SIZE,
validation_split=VALIDATION_SPLIT,
)
model = build_model(
vocab_size=len(tokenizer.vocab),
maxlen=MAXLEN,
embed_dim=EMBED_DIM,
num_heads=NUM_HEADS,
ff_dim=FF_DIM,
dropout=DROPOUT,
num_layers=NUM_LAYERS,
jit_compile=True,
)
print(
f"Training on {{train_count:,}} sequences"
f"{{f' with {{val_count:,}} validation sequences' if val_count else ''}}."
)
print(f"Vocab: {{len(tokenizer.vocab):,}} | Maxlen: {{MAXLEN}} | Batch: {{BATCH_SIZE}}")
model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
config = {{
"dataset_id": DATASET_ID,
"dataset_split": SPLIT,
"max_rows": MAX_ROWS,
"maxlen": MAXLEN,
"embed_dim": EMBED_DIM,
"num_heads": NUM_HEADS,
"num_layers": NUM_LAYERS,
"ff_dim": FF_DIM,
"dropout": DROPOUT,
"vocab_size": len(tokenizer.vocab),
"tokenizer_chunk_size": tokenizer.chunk_size,
"tokenizer_word_chunk_size": tokenizer.word_chunk_size,
"tokenizer_cyrillic_chunk_size": tokenizer.cyrillic_chunk_size,
"tokenizer_version": tokenizer.tokenizer_version,
"stride": STRIDE,
"max_sequences": MAX_SEQUENCES,
"answer_only_loss": ANSWER_ONLY_LOSS,
"mixed_precision": tf_settings["mixed_precision"],
"xla": tf_settings["xla"],
}}
save_artifacts(model, tokenizer, MODEL_DIR, config)
print(f"Saved model, tokenizer, and config to {{MODEL_DIR}}")
"""
    # Plain (non-f) string: its braces are literal and are evaluated by the
    # f-string inside the generated cell. It relies on MODEL_DIR, which the
    # training cell defines.
    zip_code = """
import shutil
from google.colab import files
archive_path = shutil.make_archive(MODEL_DIR, "zip", MODEL_DIR)
print(f"Created archive: {archive_path}")
files.download(archive_path)
"""
    return {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": {
            "accelerator": "GPU",
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {"name": "python"},
            "colab": {"gpuType": "T4"},
        },
        "cells": [
            make_notebook_cell(
                "markdown",
                f"# Train the tiny TG model on `{dataset_id}`\n\n"
                "Run this notebook in Google Colab with a GPU runtime. "
                "It downloads the Hugging Face dataset inside Colab, trains the model, "
                "and saves `model.keras`, `tokenizer.json`, and `config.json`.",
            ),
            make_notebook_cell("code", "%pip -q install datasets huggingface_hub"),
            # repr() embeds the module source as one escaped string literal.
            make_notebook_cell(
                "code",
                "from pathlib import Path\n"
                f"Path('app.py').write_text({app_source!r}, encoding='utf-8')\n"
                "print('Wrote app.py into the Colab runtime.')",
            ),
            make_notebook_cell("code", train_code.strip()),
            make_notebook_cell(
                "markdown",
                "Run the next cell after training if you want to download the saved artifacts as a zip.",
            ),
            make_notebook_cell("code", zip_code.strip()),
        ],
    }
def write_colab_notebook(output_path=DEFAULT_COLAB_NOTEBOOK, dataset_id=TG_DATASET_PRO_ID):
    """Generate the Colab notebook for ``dataset_id`` and write it to ``output_path``.

    Returns:
        The ``Path`` the notebook was written to.
    """
    notebook = build_colab_notebook(dataset_id=dataset_id)
    destination = Path(output_path)
    serialized = json.dumps(notebook, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")
    return destination
def colab_command(args):
    """Write the Colab notebook described by ``args`` and print next steps."""
    notebook_path = write_colab_notebook(args.output, dataset_id=args.dataset_id)
    messages = (
        f"Created Colab notebook: {notebook_path}",
        "Upload it to Google Colab, choose a GPU runtime, then run all cells.",
    )
    for message in messages:
        print(message)
def ask_text(prompt, default=None, required=False):
    """Prompt on stdin and return the stripped answer.

    An empty answer returns ``str(default)`` when a default exists, ``""``
    when the value is optional, and re-prompts when it is required.
    """
    has_default = default is not None
    suffix = f" [{default}]" if has_default else ""
    while True:
        answer = input(f"{prompt}{suffix}: ").strip()
        if answer:
            return answer
        if has_default:
            return str(default)
        if required:
            print("Please enter a value.")
            continue
        return ""
def ask_int(prompt, default):
    """Prompt until the answer (or ``default``) parses as an ``int``."""
    while True:
        answer = ask_text(prompt, default)
        try:
            parsed = int(answer)
        except ValueError:
            print("Please enter a whole number.")
        else:
            return parsed
def ask_float(prompt, default):
    """Prompt until the answer (or ``default``) parses as a ``float``."""
    while True:
        answer = ask_text(prompt, default)
        try:
            parsed = float(answer)
        except ValueError:
            print("Please enter a number.")
        else:
            return parsed
def ask_bool(prompt, default=True):
    """Prompt until the answer is a recognized yes/no word and return it as a bool."""
    affirmative = {"y", "yes", "true", "1"}
    negative = {"n", "no", "false", "0"}
    default_text = "y" if default else "n"
    while True:
        answer = ask_text(prompt, default_text).lower()
        if answer in affirmative:
            return True
        if answer in negative:
            return False
        print("Please enter y or n.")
class MenuArgs:
    """Plain attribute container filled in by the interactive ``build_*_args_from_input`` helpers."""
def build_train_args_from_input():
    """Interactively collect every setting ``train_command`` needs.

    Defaults are pre-filled from ``estimate_dataset_defaults``. If the
    estimate fails for any reason, generic defaults are used instead.
    """
    args = MenuArgs()
    raw_paths = ask_text("Dataset path(s), separated by commas", required=True)
    args.data = [piece.strip() for piece in raw_paths.split(",") if piece.strip()]

    try:
        defaults = estimate_dataset_defaults(args.data)
        print("\nDataset estimate:")
        print(f"Records: {defaults['records']:,}")
        print(f"Tokens: {defaults['tokens']:,}")
        print(f"Unique tokenizer tokens: {defaults['unique_tokens']:,}")
        print("Recommended fast settings for this dataset are pre-filled below.")
    except Exception as exc:
        # The estimate is a convenience only, so any failure falls back to
        # generic defaults.
        print(f"Could not estimate dataset defaults, using generic fast defaults: {exc}")
        defaults = {
            "epochs": 3,
            "batch_size": 128,
            "validation_split": 0.02,
            "maxlen": 96,
            "stride": 48,
            "max_vocab_size": 32000,
            "max_sequences": 60000,
            "embed_dim": 96,
            "num_heads": 4,
            "num_layers": 2,
            "ff_dim": 192,
            "dropout": 0.15,
            "answer_only_loss": True,
        }

    args.model_dir = ask_text("Model save directory", "tg-medium")

    # (attribute / defaults key, prompt helper, prompt text), in asking order.
    questions = (
        ("epochs", ask_int, "Epochs"),
        ("batch_size", ask_int, "Batch size"),
        ("validation_split", ask_float, "Validation split"),
        ("maxlen", ask_int, "Context length / max sequence length"),
        ("stride", ask_int, "Training stride, higher is faster"),
        ("max_vocab_size", ask_int, "Max vocab size"),
        ("max_sequences", ask_int, "Max training sequences, 0 = all"),
        ("embed_dim", ask_int, "Embedding size"),
        ("num_heads", ask_int, "Attention heads"),
        ("num_layers", ask_int, "Transformer blocks"),
        ("ff_dim", ask_int, "Feed-forward size"),
        ("dropout", ask_float, "Dropout"),
        ("answer_only_loss", ask_bool, "Train loss only on answer tokens"),
    )
    for name, ask, label in questions:
        setattr(args, name, ask(label, defaults[name]))

    args.mixed_precision = ask_bool("Use mixed precision on GPU", True)
    args.xla = ask_bool("Use XLA/JIT compile", True)
    return args
def build_chat_args_from_input():
    """Interactively collect the settings shared by the chat and generate commands."""
    chat_args = MenuArgs()
    chat_args.model_dir = ask_text("Model directory to load", "tg-medium")
    chat_args.maxlen = ask_int("Fallback context length if config is missing", 100)
    chat_args.max_new_tokens = ask_int("Max new tokens per reply", 80)
    chat_args.temperature = ask_float("Temperature, 0 for greedy", 0.8)
    return chat_args
def build_generate_args_from_input():
    """Collect the chat settings and then a required prompt for one-shot generation."""
    generate_args = build_chat_args_from_input()
    generate_args.prompt = ask_text("Prompt", required=True)
    return generate_args
def build_save_args_from_input():
    """Interactively collect the source and destination directories for ``save_command``."""
    save_args = MenuArgs()
    save_args.model_dir = ask_text("Model directory to load", "tg-medium")
    save_args.output_dir = ask_text("Output directory for saved copy", required=True)
    return save_args
def build_colab_args_from_input():
    """Interactively collect the dataset id and output path for ``colab_command``."""
    colab_args = MenuArgs()
    colab_args.dataset_id = ask_text("Hugging Face dataset ID", TG_DATASET_PRO_ID)
    colab_args.output = ask_text("Output notebook path", DEFAULT_COLAB_NOTEBOOK)
    return colab_args
def print_menu():
    """Print the main menu, preceded by a blank line."""
    menu_lines = (
        "\nSmall LM CLI",
        "1. Train model",
        "2. Chat with model",
        "3. Generate one response",
        "4. Save/copy loaded model artifacts",
        "5. Create Colab auto-train notebook",
        "6. Exit",
    )
    for menu_line in menu_lines:
        print(menu_line)
def main():
    """Run the interactive menu loop until the user chooses to exit.

    Errors raised by a command are printed and the menu is shown again.
    """
    # (accepted inputs, argument builder, command), in menu order.
    actions = (
        ({"1", "train"}, build_train_args_from_input, train_command),
        ({"2", "chat"}, build_chat_args_from_input, chat_command),
        ({"3", "generate"}, build_generate_args_from_input, generate_command),
        ({"4", "save", "copy"}, build_save_args_from_input, save_command),
        ({"5", "colab"}, build_colab_args_from_input, colab_command),
    )
    exit_words = {"6", "exit", "quit", "q"}

    while True:
        print_menu()
        choice = input("Choose an option: ").strip().lower()
        if choice in exit_words:
            break
        selected = next(
            ((build_args, run) for aliases, build_args, run in actions if choice in aliases),
            None,
        )
        if selected is None:
            print("Please choose 1, 2, 3, 4, 5, or 6.")
            continue
        build_args, run = selected
        try:
            run(build_args())
        except Exception as exc:
            print(f"Error: {exc}")
# Start the interactive menu only when this file is executed as a script.
if __name__ == "__main__":
    main()