import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
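
# The demo loads a single TinyLlama chat checkpoint and, at every generation
# step, runs it on two prompts that differ only in their system message.
# The two next-token logit vectors are mixed as
#     blended = wA * logits_A + wB * logits_B
# before temperature scaling and top-p sampling.  The weight sliders range
# over [-5, 5], so a prompt's influence can also be inverted.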

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
model.eval()
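
# One shared model instance serves both system prompts; blend_generate below
# simply runs it twice per step, so no second copy of the weights is needed.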


def blend_generate(sys_a, sys_b, wa, wb, user_msg, max_tokens, temperature, top_p):
    """Sample a reply whose per-step logits are a weighted blend of two prompts."""

    def build_ids(system_prompt):
        # Pair each system prompt with the same user turn and render it
        # through the model's chat template.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]
        return tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)

    ids_a = build_ids(sys_a)
    ids_b = build_ids(sys_b)

    generated = []
    for _ in range(int(max_tokens)):
        with torch.no_grad():
            logits_a = model(ids_a).logits[:, -1, :]
            logits_b = model(ids_b).logits[:, -1, :]

        # Blend in logit space, then apply temperature.
        blended_logits = wa * logits_a + wb * logits_b
        probs = F.softmax(blended_logits.float() / temperature, dim=-1)

        # Top-p (nucleus) filtering: drop tokens whose preceding cumulative
        # mass already exceeds top_p, then renormalise and sample.
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        next_token = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))

        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())

        # Feed the sampled token back into both contexts so they stay in sync.
        ids_a = torch.cat([ids_a, next_token], dim=-1)
        ids_b = torch.cat([ids_b, next_token], dim=-1)

    return tokenizer.decode(generated, skip_special_tokens=True)
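

# Gradio UI: two system prompts, per-prompt blend weights, a shared user
# message, and the usual sampling controls.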
with gr.Blocks() as demo:
    gr.Markdown("## Blended Prompt Chat (TinyLlama)")
    sysA = gr.Textbox(label="System Prompt A", value="You are assistant A.")
    sysB = gr.Textbox(label="System Prompt B", value="You are assistant B.")
    wA = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight A")
    wB = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight B")
    user_msg = gr.Textbox(label="User Message")
    temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    max_tokens = gr.Slider(1, 200, value=100, step=1, label="Max New Tokens")
    output = gr.Textbox(label="Response")

    btn = gr.Button("Generate")
    btn.click(
        blend_generate,
        [sysA, sysB, wA, wB, user_msg, max_tokens, temp, top_p],
        output,
        show_progress=True,
    )

if __name__ == "__main__":
    demo.launch()