import keras
from keras import ops


def greedy_search(preds):
    """Return the index of the largest entry in ``preds``.

    Deterministic: this is a plain ``ops.argmax`` with no axis argument,
    so no randomness is involved.
    """
    best_index = ops.argmax(preds)
    return best_index


def random_sample(preds, temperature=1.0):
    """Draw one index at random after temperature-scaling ``preds``.

    ``preds`` is divided by ``temperature`` and handed to
    ``keras.random.categorical``, which treats each row as unnormalized
    log-probabilities (logits).

    Args:
        preds: Scores for a single step.
            NOTE(review): presumably a 1-D vector of logits -- confirm
            against the caller.
        temperature: Divisor applied to ``preds`` before sampling.

    Returns:
        The first row of the ``num_samples=1`` categorical draw.
    """
    scaled = preds / temperature
    # categorical() wants a leading batch axis, so add one of size 1 ...
    batched = scaled[None, :]
    draws = keras.random.categorical(batched, num_samples=1)
    # ... and strip it again from the result.
    return draws[0]


def top_k(preds, k=5, temperature=1.0):
    """Temperature-sample an index from only the ``k`` largest entries.

    Args:
        preds: Scores for a single step.
            NOTE(review): presumably a 1-D vector of logits -- confirm
            against the caller.
        k: Number of highest-scoring candidates kept for sampling.
        temperature: Divisor applied to ``preds`` before the candidates
            are selected.

    Returns:
        The entry of the top-k index list selected by the categorical
        draw, i.e. an index into the original ``preds``.
    """
    scaled = preds / temperature
    # The order of the k candidates does not matter for sampling, so the
    # sort is skipped.
    candidate_scores, candidate_ids = ops.top_k(scaled, k=k, sorted=False)
    draws = keras.random.categorical(candidate_scores[None, :], num_samples=1)
    # Position within the k candidates, not yet an index into ``preds``.
    picked_slot = draws[0]
    # Map that position back to the corresponding index of ``preds``.
    return ops.take_along_axis(candidate_ids, picked_slot, axis=-1)