"""HawkGPT 0.5 — Text generation."""
import os
import sys

from gpu_setup import ensure_gpu

# NOTE(review): ensure_gpu() is deliberately invoked before the TensorFlow
# import below — presumably it prepares the GPU environment; keep this order.
ensure_gpu()
# Set ahead of `import tensorflow` so the native logger sees the level.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import tensorflow as tf
import argparse

import config
from tokenizer_module import load_tokenizer
from model import build_model


def load_trained_model(vocab_size, checkpoint_path):
    """Build the model for ``vocab_size`` and load weights from a checkpoint.

    Args:
        vocab_size: Vocabulary size forwarded to ``build_model``.
        checkpoint_path: Path passed to ``model.load_weights``.

    Returns:
        The model instance with the checkpoint weights loaded.
    """
    model = build_model(vocab_size)
    model.load_weights(checkpoint_path)
    print(f"Loaded weights from {checkpoint_path}")
    return model


@tf.function(reduce_retracing=True)
def generate_step(model, token_ids, temperature, top_k):
    """Sample one next-token id for each row of ``token_ids``.

    Args:
        model: Callable invoked as ``model(token_ids, training=False)``; its
            output is indexed as ``[:, -1, :]`` to take the last position.
        token_ids: Integer tensor of context token ids
            (assumes shape ``(batch, seq_len)`` — TODO confirm against model).
        temperature: Divisor applied to the logits; must be non-zero.
        top_k: When greater than 0, only the ``top_k`` highest logits of each
            row stay eligible for sampling; 0 or less disables the filter.

    Returns:
        The result of ``tf.random.categorical(..., num_samples=1)``: one
        sampled id per row.
    """
    logits = model(token_ids, training=False)[:, -1, :] / temperature
    if top_k > 0:
        # tf.math.top_k fails when k exceeds the last dimension, so clamp it
        # to the number of available logits.
        k = tf.minimum(top_k, tf.shape(logits)[-1])
        top_k_vals, _ = tf.math.top_k(logits, k=k)
        # Everything below the k-th largest logit gets a hugely negative
        # score, which makes its sampling probability effectively zero.
        logits = tf.where(logits < top_k_vals[:, -1:], -1e9, logits)
    # tf.random.categorical expects unnormalised log-probabilities, so the
    # logits can be passed as-is; softmax followed by log is redundant and
    # yields -inf for the masked entries.
    return tf.random.categorical(logits, num_samples=1)


def generate(model, tokenizer, prompt, max_new_tokens=200, temperature=0.8,
             top_k=50, num_return=1):
    """Generate ``num_return`` continuations of ``prompt``.

    The tokenizer must provide ``token_to_id``, ``no_padding``,
    ``no_truncation``, ``encode(...).ids`` and ``decode`` (this looks like
    the HuggingFace ``tokenizers`` API — confirm against ``load_tokenizer``).
    Note that padding and truncation are switched off on the tokenizer
    object that is passed in.

    Args:
        model: Model handed to ``generate_step``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_new_tokens: Upper bound on tokens sampled per continuation.
        temperature: Sampling temperature; must be greater than 0.
        top_k: Top-k filter size forwarded to ``generate_step``.
        num_return: Number of independent continuations to produce.

    Returns:
        A list of ``num_return`` strings, each being ``prompt`` followed by
        the decoded newly generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, or the tokenizer has
            no ``[BOS]`` token.
    """
    if temperature <= 0:
        # Dividing logits by zero (or a negative value) produces inf/NaN or
        # an inverted distribution rather than a meaningful sample.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    pad_id = tokenizer.token_to_id("[PAD]")
    eos_id = tokenizer.token_to_id("[EOS]")
    bos_id = tokenizer.token_to_id("[BOS]")
    if bos_id is None:
        # A None id would otherwise surface as an opaque tensor-conversion
        # error once the context is turned into a tf.constant.
        raise ValueError("tokenizer has no [BOS] token")

    tokenizer.no_padding()
    tokenizer.no_truncation()

    prompt_ids = [bos_id, *tokenizer.encode(prompt).ids]
    # Drop a trailing EOS so the model continues the prompt instead of
    # seeing it as finished; never strip the only token, which would leave
    # an empty context.
    if len(prompt_ids) > 1 and prompt_ids[-1] == eos_id:
        prompt_ids.pop()

    stop_ids = (eos_id, pad_id)
    prompt_len = len(prompt_ids)
    results = []
    for _ in range(num_return):
        generated = list(prompt_ids)
        for _ in range(max_new_tokens):
            # Feed at most the last MAX_SEQ_LEN tokens to the model.
            ctx = generated[-config.MAX_SEQ_LEN:]
            token_tensor = tf.constant([ctx], dtype=tf.int32)
            sampled = generate_step(model, token_tensor, temperature, top_k)
            # Convert to a plain int so the id list handed to the tokenizer
            # contains no numpy scalars.
            next_id = int(sampled.numpy()[0, 0])
            if next_id in stop_ids:
                break
            generated.append(next_id)
        text = tokenizer.decode(generated[prompt_len:])
        results.append(prompt + text)
    return results


def main():
    """Parse CLI arguments, load tokenizer and model, and print samples."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Вопрос: Привет!")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=os.path.join(config.CHECKPOINT_DIR, "model_best.weights.h5"),
    )
    parser.add_argument("--max_tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--num_return", type=int, default=5)
    args = parser.parse_args()

    # Reject an unusable temperature up front with a usage message instead
    # of failing after the model has been loaded.
    if args.temperature <= 0:
        parser.error("--temperature must be > 0")

    tokenizer = load_tokenizer()
    model = load_trained_model(tokenizer.get_vocab_size(), args.checkpoint)

    print(f"\nPrompt: {args.prompt}")
    print(f"Temp: {args.temperature} | Top-K: {args.top_k} | Tokens: {args.max_tokens}")
    print("=" * 60)

    outputs = generate(
        model,
        tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        num_return=args.num_return,
    )
    for i, text in enumerate(outputs):
        print(f"\n--- Sample {i+1} ---")
        print(text)


if __name__ == "__main__":
    main()