general-deep-learning / pipeline /base /sample_functions.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
import keras
from keras import ops
def greedy_search(preds):
    """Deterministically pick the highest-scoring index (greedy decoding).

    ``axis=None`` is the ``ops.argmax`` default, spelled out here: the index
    refers to the flattened ``preds``. For a 1-D score vector (presumably what
    callers pass; confirm) that is simply the arg-max candidate id.

    Returns:
        A 0-D integer tensor holding the chosen index.
    """
    best_index = ops.argmax(preds, axis=None)
    return best_index
def random_sample(preds, temperature=1.0, *, seed=None):
    """Draw one index from ``preds`` using temperature-scaled sampling.

    ``preds`` is divided by ``temperature`` and handed to
    ``keras.random.categorical``, which interprets its input as unnormalized
    log-probabilities (logits). ``preds`` should therefore be raw logits
    rather than softmax probabilities. NOTE(review): confirm against callers.

    Args:
        preds: 1-D tensor of scores, one per candidate. A batch axis is added
            internally with ``preds[None, :]``.
        temperature: Strictly positive divisor applied to ``preds``. For
            logits, values below 1 sharpen the distribution and values above
            1 flatten it.
        seed: Optional seed (int or ``keras.random.SeedGenerator``) forwarded
            to ``keras.random.categorical`` for reproducible sampling.

    Returns:
        Integer tensor of shape ``(1,)`` holding the sampled index.

    Raises:
        ValueError: If ``temperature`` is a Python number <= 0. Zero would
            divide by zero and produce inf/NaN logits; a negative value would
            silently invert the distribution.
    """
    # Only Python scalars are validated; tensor temperatures pass through
    # untouched so traced/compiled callers keep working.
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    scaled = preds / temperature
    # categorical expects (batch, num_classes): add a batch of one, draw a
    # single sample, then drop the batch axis -> shape (1,).
    return keras.random.categorical(scaled[None, :], num_samples=1, seed=seed)[0]
def top_k(preds, k=5, temperature=1.0, *, seed=None):
    """Temperature-sample one index from among the ``k`` highest scores.

    The ``k`` largest entries of ``preds / temperature`` are kept. One of them
    is drawn with ``keras.random.categorical``, which treats its input as
    logits, so ``preds`` should be raw logits (NOTE(review): confirm against
    callers). The drawn position is then mapped back to an index into the
    full ``preds`` vector.

    Args:
        preds: 1-D tensor of scores, one per candidate.
        k: Number of top-scoring candidates to sample from; must be >= 1.
            If it exceeds the (statically known) number of candidates, it is
            clamped, which reduces to plain temperature sampling.
        temperature: Strictly positive divisor applied to ``preds``.
        seed: Optional seed (int or ``keras.random.SeedGenerator``) forwarded
            to ``keras.random.categorical`` for reproducible sampling.

    Returns:
        Integer tensor of shape ``(1,)`` holding an index into ``preds``.

    Raises:
        ValueError: If ``temperature`` is a Python number <= 0, or if
            ``k < 1``.
    """
    # Only Python scalars are validated; tensor temperatures pass through
    # untouched so traced/compiled callers keep working.
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k!r}")
    scaled = preds / temperature
    # Backend top-k implementations generally reject k larger than the last
    # dimension. Top-k over every candidate is ordinary temperature sampling,
    # so clamp when the size is known statically.
    num_candidates = scaled.shape[-1]
    if num_candidates is not None:
        k = min(k, num_candidates)
    # Ordering is irrelevant: we sample among the survivors, not rank them.
    top_scores, top_indices = ops.top_k(scaled, k=k, sorted=False)
    # Position (shape (1,)) within the k survivors ...
    choice = keras.random.categorical(top_scores[None, :], num_samples=1, seed=seed)[0]
    # ... translated back to the corresponding index into ``preds``.
    return ops.take_along_axis(top_indices, choice, axis=-1)