SFT / app.py
jodiox's picture
update model to jodiox/olmo3-190m-zh-nano-sft
2f62c52 verified
Raw
History Blame Contribute Delete
1.09 kB
import gradio as gr
from transformers import pipeline, AutoTokenizer
# Hub id of the checkpoint this demo serves; used for both tokenizer and model.
model_name = "jodiox/olmo3-190m-zh-nano-sft"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Collect the token ids that should terminate generation.
eos_token_ids = []

# 1) The tokenizer's own end-of-sequence id, when one is configured.
default_eos_id = tokenizer.eos_token_id
if default_eos_id is not None:
    eos_token_ids.append(default_eos_id)

# 2) The "<|im_end|>" marker, but only when the lookup yields a real id:
#    a result of None or the unknown-token id is skipped.
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
if im_end_id not in (None, tokenizer.unk_token_id):
    eos_token_ids.append(im_end_id)

# Text-generation pipeline that reuses the tokenizer loaded above and asks for
# the "sdpa" attention implementation when the model is loaded.
pipe = pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    model_kwargs={"attn_implementation": "sdpa"},
)
def predict(message):
    """Generate one sampled reply to a single user message.

    Args:
        message: Text used as the content of a single "user" turn.

    Returns:
        The ``generated_text`` field of the first result returned by the
        module-level ``pipe``.
    """
    conversation = [{"role": "user", "content": message}]

    generation_options = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "return_full_text": False,
        # Key setting: tell the pipeline to stop once any of these ids appears.
        "eos_token_id": eos_token_ids,
    }

    results = pipe(conversation, **generation_options)
    return results[0]["generated_text"]
# Wrap predict in a text-in / text-out Gradio interface and start serving it.
demo = gr.Interface(fn=predict, inputs="text", outputs="text")
demo.launch()