| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer | |
# Hugging Face Hub id of the model served by this app.
model_name = "jodiox/olmo3-190m-zh-nano-sft"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Collect the ids of the tokens that should end generation: the tokenizer's
# own EOS token plus "<|im_end|>" (presumably the chat-turn terminator of the
# model's template -- confirm against the model card).
eos_token_ids = []
if tokenizer.eos_token_id is not None:
    eos_token_ids.append(tokenizer.eos_token_id)

im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
# A lookup result of None or the unk id is treated as "not in the vocabulary".
# Also skip the id when it is already listed (i.e. "<|im_end|>" is the EOS
# token itself) so the stop list carries no duplicate entry.
if (
    im_end_id is not None
    and im_end_id != tokenizer.unk_token_id
    and im_end_id not in eos_token_ids
):
    eos_token_ids.append(im_end_id)

pipe = pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    model_kwargs={"attn_implementation": "sdpa"},
)
def predict(message: str) -> str:
    """Generate a single-turn chat reply to *message*.

    Args:
        message: The user's prompt text.

    Returns:
        Only the newly generated text (``return_full_text=False``, so the
        prompt is not echoed back), or an empty string when *message* is
        empty or whitespace-only.
    """
    # Nothing to answer: skip a sampled generation of up to 256 tokens on a
    # blank prompt instead of returning arbitrary text.
    if not message or not message.strip():
        return ""

    messages = [{"role": "user", "content": message}]
    output = pipe(
        messages,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        return_full_text=False,
        # Key: tell the pipeline to stop as soon as it emits one of these tokens.
        eos_token_id=eos_token_ids,
    )
    return output[0]["generated_text"]
# Build the text-in / text-out web UI around predict and start serving it.
demo = gr.Interface(fn=predict, inputs="text", outputs="text")
demo.launch()