HawkGPT-v0.5 / generate.py
HawkLabofficial's picture
Upload generate.py with huggingface_hub
72a5fe2 verified
Raw
History Blame Contribute Delete
3.16 kB
"""HawkGPT 0.5 — Text generation."""
import os, sys
from gpu_setup import ensure_gpu
ensure_gpu()
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import numpy as np
import tensorflow as tf
import argparse
import config
from tokenizer_module import load_tokenizer
from model import build_model
def load_trained_model(vocab_size, checkpoint_path):
    """Build a fresh model and restore its weights from a checkpoint file.

    Args:
        vocab_size: Vocabulary size forwarded to ``build_model``.
        checkpoint_path: Weights file handed to ``load_weights``.

    Returns:
        The freshly built model with the checkpoint weights loaded.
    """
    net = build_model(vocab_size)
    net.load_weights(checkpoint_path)
    message = f"Loaded weights from {checkpoint_path}"
    print(message)
    return net
@tf.function(reduce_retracing=True)
def generate_step(model, token_ids, temperature, top_k):
    """Pick the next token id for every row of ``token_ids``.

    Args:
        model: Model called as ``model(token_ids, training=False)``; its output
            is indexed as ``[:, -1, :]``, so it presumably returns
            (batch, seq, vocab) logits -- TODO confirm against ``build_model``.
        token_ids: Token-id tensor; ``generate`` passes an int32 tensor of
            shape (1, seq).
        temperature: Sampling temperature. Values <= 0 switch to greedy
            (argmax) decoding instead of dividing the logits by zero.
        top_k: When > 0, every logit strictly below the k-th largest is masked
            out before sampling (ties with the k-th value survive). Values
            larger than the vocabulary are clamped to the vocabulary size.

    Returns:
        int64 tensor of shape (batch, 1) holding the chosen token ids.
    """
    # Work in float32 regardless of the model's output dtype so the -1e9 mask
    # below is always representable.
    logits = tf.cast(model(token_ids, training=False)[:, -1, :], tf.float32)
    if temperature <= 0:
        # Greedy decoding; keep the (batch, 1) int64 shape of the sampling path.
        return tf.expand_dims(tf.argmax(logits, axis=-1), axis=-1)
    logits = logits / temperature
    if top_k > 0:
        # tf.math.top_k fails when k exceeds the last dimension, so clamp it.
        k = tf.minimum(top_k, tf.shape(logits)[-1])
        top_k_vals, _ = tf.math.top_k(logits, k=k)
        logits = tf.where(logits < top_k_vals[:, -1:], -1e9, logits)
    # tf.random.categorical expects unnormalised log-probabilities, so the
    # logits are passed directly (softmax followed by log was a no-op round trip).
    return tf.random.categorical(logits, num_samples=1)
def generate(model, tokenizer, prompt, max_new_tokens=200, temperature=0.8, top_k=50, num_return=1):
    """Sample ``num_return`` continuations of ``prompt``.

    Each sample is produced autoregressively: at every step at most the last
    ``config.MAX_SEQ_LEN`` ids are fed to ``generate_step``, and generation
    stops early when the sampled id is ``[EOS]`` or ``[PAD]``.

    Args:
        model: Model forwarded to ``generate_step``.
        tokenizer: Tokenizer exposing ``token_to_id``, ``encode``, ``decode``,
            ``no_padding`` and ``no_truncation`` (presumably a Hugging Face
            ``tokenizers.Tokenizer`` -- confirm in ``tokenizer_module``).
            NOTE: padding and truncation are disabled on this object in place
            and are not restored afterwards.
        prompt: Prompt text; it is prepended verbatim to every result.
        max_new_tokens: Maximum number of sampled tokens per result.
        temperature: Sampling temperature forwarded to ``generate_step``.
        top_k: Top-k cutoff forwarded to ``generate_step`` (0 disables it).
        num_return: Number of independent samples to draw.

    Returns:
        A list of ``num_return`` strings, each ``prompt`` followed by the
        decoded continuation.

    Raises:
        ValueError: If the prompt encodes to no ids and the vocabulary has no
            ``[BOS]`` token, leaving nothing to condition on.
    """
    pad_id = tokenizer.token_to_id("[PAD]")
    eos_id = tokenizer.token_to_id("[EOS]")
    bos_id = tokenizer.token_to_id("[BOS]")
    # Generation needs the raw prompt ids: no padding, no truncation.
    tokenizer.no_padding()
    tokenizer.no_truncation()
    enc = tokenizer.encode(prompt)
    # token_to_id returns None for unknown tokens; only prepend a real BOS id
    # instead of putting None into the id list (which tf.constant rejects).
    prompt_ids = ([bos_id] if bos_id is not None else []) + list(enc.ids)
    # A trailing [EOS] would tell the model the text is already finished.
    if prompt_ids and prompt_ids[-1] == eos_id:
        prompt_ids = prompt_ids[:-1]
    if not prompt_ids:
        raise ValueError("Prompt encodes to no tokens and the tokenizer has no [BOS] token.")
    # Ids that end a sample; None entries (tokens missing from vocab) are dropped.
    stop_ids = {i for i in (eos_id, pad_id) if i is not None}

    results = []
    for _ in range(num_return):
        generated = list(prompt_ids)
        for _ in range(max_new_tokens):
            # Keep only the most recent MAX_SEQ_LEN ids as model context.
            ctx = generated[-config.MAX_SEQ_LEN:]
            token_tensor = tf.constant([ctx], dtype=tf.int32)
            next_token = generate_step(model, token_tensor, temperature, top_k)
            # Convert the numpy scalar to a plain int so `generated` stays a
            # list of Python ints for tf.constant and tokenizer.decode.
            next_id = int(next_token.numpy()[0, 0])
            if next_id in stop_ids:
                break
            generated.append(next_id)
        new_ids = generated[len(prompt_ids):]
        results.append(prompt + tokenizer.decode(new_ids))
    return results
def main():
    """Command-line entry point: parse flags, load tokenizer and model, print samples."""
    cli = argparse.ArgumentParser()
    default_ckpt = os.path.join(config.CHECKPOINT_DIR, "model_best.weights.h5")
    # The default prompt is Russian for "Question: Hello!".
    cli.add_argument("--prompt", type=str, default="Вопрос: Привет!")
    cli.add_argument("--checkpoint", type=str, default=default_ckpt)
    cli.add_argument("--max_tokens", type=int, default=80)
    cli.add_argument("--temperature", type=float, default=0.7)
    cli.add_argument("--top_k", type=int, default=50)
    cli.add_argument("--num_return", type=int, default=5)
    opts = cli.parse_args()

    tokenizer = load_tokenizer()
    model = load_trained_model(tokenizer.get_vocab_size(), opts.checkpoint)

    banner = (
        f"\nPrompt: {opts.prompt}",
        f"Temp: {opts.temperature} | Top-K: {opts.top_k} | Tokens: {opts.max_tokens}",
        "=" * 60,
    )
    for line in banner:
        print(line)

    samples = generate(
        model,
        tokenizer,
        prompt=opts.prompt,
        max_new_tokens=opts.max_tokens,
        temperature=opts.temperature,
        top_k=opts.top_k,
        num_return=opts.num_return,
    )
    for number, sample in enumerate(samples, start=1):
        print(f"\n--- Sample {number} ---")
        print(sample)


if __name__ == "__main__":
    main()