import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Choose the device once; load half-precision weights only when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
).to(device)
model.eval()


def _encode_chat(system_prompt, user_message):
    """Render a (system, user) exchange with the model's chat template and tokenize it.

    Returns a (1, seq_len) tensor of input ids on ``device``, ending with the
    assistant-turn marker so the model continues as the assistant.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # add_special_tokens=False yields the same ids as apply_chat_template(tokenize=True):
    # the template itself decides which special tokens appear.
    return tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)


def _step(input_ids, past):
    """Run one forward pass; return fp32 next-token logits (1, vocab) and the updated KV cache."""
    out = model(input_ids=input_ids, past_key_values=past, use_cache=True)
    # Upcast so the weighted blend, temperature scaling and softmax run in full precision.
    return out.logits[:, -1, :].float(), out.past_key_values


def _top_p_filter(logits, top_p):
    """Set to -inf every logit outside the smallest set whose probability mass reaches top_p."""
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    # Exclusive cumulative mass: a token is dropped once the tokens ranked above it
    # already cover top_p.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    remove = mass_before >= top_p
    remove[..., 0] = False  # always keep the most likely token, even for top_p <= 0
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Undo the sort so positions line up with vocabulary ids again.
    return torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)


def blend_generate(sys_a, sys_b, wa, wb, user_msg, max_tokens=100, temperature=1.0, top_p=0.9):
    """Generate a reply by blending next-token logits from two system-prompt contexts.

    The same user message is encoded twice, once under each system prompt. At every
    step the model scores the next token in both contexts. The raw logits are combined
    as ``wa * logits_a + wb * logits_b``, divided by ``temperature``, filtered with
    nucleus (top-p) sampling, and one token is sampled. That token is appended to both
    contexts, so the two continuations stay identical.

    Args:
        sys_a: System prompt for context A.
        sys_b: System prompt for context B.
        wa: Weight applied to context A's logits (may be negative).
        wb: Weight applied to context B's logits (may be negative).
        user_msg: The user's message, shared by both contexts.
        max_tokens: Maximum number of new tokens to generate (coerced to int).
        temperature: Softmax temperature; clamped to a small positive minimum.
        top_p: Nucleus-sampling probability mass; values >= 1.0 disable filtering.

    Returns:
        The decoded generated text, excluding special tokens. Generation stops early
        at the tokenizer's EOS token.
    """
    ids_a = _encode_chat(sys_a, user_msg)
    ids_b = _encode_chat(sys_b, user_msg)
    temperature = max(float(temperature), 1e-5)  # guard against division by zero
    wa, wb = float(wa), float(wb)

    generated = []
    past_a = past_b = None
    with torch.inference_mode():
        for _ in range(int(max_tokens)):
            logits_a, past_a = _step(ids_a, past_a)
            logits_b, past_b = _step(ids_b, past_b)

            # Weighted sum of raw logits (before softmax), then temperature and top-p.
            blended = (wa * logits_a + wb * logits_b) / temperature
            probs = F.softmax(_top_p_filter(blended, top_p), dim=-1)

            next_id = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
            token_id = next_id.item()
            if token_id == tokenizer.eos_token_id:
                break
            generated.append(token_id)

            # With the KV caches holding the history, only the new token is fed next step.
            ids_a = ids_b = next_id

    return tokenizer.decode(generated, skip_special_tokens=True)


with gr.Blocks() as demo:
    gr.Markdown("## Blended Prompt Chat (TinyLlama)")
    sysA = gr.Textbox(label="System Prompt A", value="You are assistant A.")
    sysB = gr.Textbox(label="System Prompt B", value="You are assistant B.")
    wA = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight A")
    wB = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight B")
    user_msg = gr.Textbox(label="User Message")
    temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    max_tokens = gr.Slider(1, 200, value=100, step=1, label="Max New Tokens")
    output = gr.Textbox(label="Response")
    btn = gr.Button("Generate")
    # Input order must match blend_generate's positional parameters.
    btn.click(
        blend_generate,
        inputs=[sysA, sysB, wA, wB, user_msg, max_tokens, temp, top_p],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()