import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
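
# Blended Prompt Chat: generate a single response in which every token is
# sampled from a weighted combination of the model's logits under two
# different system prompts (A and B).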
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
# The dtype check alone does not move the weights; place the model explicitly.
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
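
# How the blend works: the same model scores two prompts that differ only in
# the system message, and each new token is drawn once from
# wA * logitsA + wB * logitsB. Equal weights average the two "personas"; a
# negative weight (the sliders allow -5..5) steers generation away from that
# prompt, in the spirit of contrastive / classifier-free guidance for text.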
def blend_generate(sysA, sysB, wA, wB, user_message, max_new_tokens, temperature, top_p):
    # TinyLlama-Chat was trained on the Zephyr template: role tags on their
    # own lines, each turn closed with </s>.
    promptA = f"<|system|>\n{sysA}</s>\n<|user|>\n{user_message}</s>\n<|assistant|>\n"
    promptB = f"<|system|>\n{sysB}</s>\n<|user|>\n{user_message}</s>\n<|assistant|>\n"
    idsA = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
    idsB = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
    outA, outB = idsA.clone(), idsB.clone()
    for _ in range(int(max_new_tokens)):  # Gradio sliders deliver floats; range() needs an int
        with torch.no_grad():
            logitsA = model(input_ids=outA).logits[:, -1, :]
            logitsB = model(input_ids=outB).logits[:, -1, :]
        # Mix the two next-token distributions in logit space, then temper.
        blended = (wA * logitsA + wB * logitsB) / temperature
        # Softmax in fp32 keeps sampling numerically stable under fp16 weights.
        probs = F.softmax(blended.float(), dim=-1)
        # Nucleus (top-p) filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative mass exceeds top_p. Masking on
        # `cum > top_p` directly would also zero the token that crosses the
        # threshold (and the entire row whenever the top token alone exceeds
        # top_p), so mask on the cumulative mass *before* each token instead.
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cum - sorted_probs > top_p] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        token = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1)).squeeze()
        # Append the shared sampled token to both contexts so they stay in sync.
        outA = torch.cat([outA, token.view(1, 1)], dim=1)
        outB = torch.cat([outB, token.view(1, 1)], dim=1)
        # Re-decode the whole continuation each step: decoding tokens one at a
        # time drops SentencePiece word-boundary markers and glues words together.
        yield tokenizer.decode(outA[0, idsA.shape[1]:], skip_special_tokens=True)
        if token.item() == tokenizer.eos_token_id:
            break
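
# Performance note: each step above re-encodes both full sequences, so
# generation cost grows quadratically with length. A sketch of the incremental
# alternative, assuming the standard Hugging Face cache API: pass
# use_cache=True, keep the returned past_key_values, and feed only the newest
# token on subsequent steps, e.g.
#   stepA = model(input_ids=token.view(1, 1), past_key_values=pastA, use_cache=True)
#   pastA = stepA.past_key_values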
with gr.Blocks() as demo:
    gr.Markdown("## Blended Prompt Chat (TinyLlama)")
    sysA = gr.Textbox(label="System Prompt A", value="You are assistant A.")
    sysB = gr.Textbox(label="System Prompt B", value="You are assistant B.")
    wA = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight A")
    wB = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight B")
    user_msg = gr.Textbox(label="User Message")
    temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    max_tokens = gr.Slider(1, 200, value=100, step=1, label="Max New Tokens")
    output = gr.Textbox(label="Response")
    btn = gr.Button("Generate")
    # blend_generate is a generator, so Gradio streams each yielded partial
    # response into the textbox automatically.
    btn.click(
        blend_generate,
        [sysA, sysB, wA, wB, user_msg, max_tokens, temp, top_p],
        output,
        show_progress="full",
    )
if __name__ == "__main__":
    demo.launch()
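    # launch() serves the app locally (http://127.0.0.1:7860 by default);
    # pass share=True for a temporary public link.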