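# Gradio chat demo for the vietrix/viena-60m model, run on CPU with streaming output.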
import gradio as gr
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM
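
# Load the tokenizer and model once at startup; float32 on CPU keeps the small model simple to host.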
MODEL_ID = "vietrix/viena-60m"
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, legacy=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu",
)
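
# Streaming chat handler wired into gr.ChatInterface below.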
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,  # injected by Gradio's OAuth support; unused by this local model
):
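    # Assemble the conversation: system prompt, prior turns, then the new user message.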
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})
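    # Prefer the tokenizer's built-in chat template; otherwise fall back to a plain ROLE: text format.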
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        # Very simple fallback when the tokenizer ships no chat template.
        parts = []
        for m in messages:
            parts.append(f"{m['role'].upper()}: {m['content']}")
        parts.append("ASSISTANT:")
        prompt = "\n".join(parts)
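    # Tokenize the prompt and sample a completion on CPU.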
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.inference_mode():  # no gradients needed for generation
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=1.15,
            no_repeat_ngram_size=4,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
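    # Strip the prompt tokens from the output, then stream the reply character by character.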
    gen_ids = outputs[0, inputs.input_ids.shape[1]:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)
    resp = ""
    for ch in text:
        resp += ch
        yield resp
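
# Chat UI with the sampling controls exposed as additional inputs.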
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
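
# Page layout: a sidebar login button (enables the OAuth token above) plus the chat interface.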
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()
if __name__ == "__main__":
    demo.launch()