"""Blended-logit text generation demo.

Loads two causal language models (by default the same Mistral-7B checkpoint
twice) and serves a Gradio UI in which each generated token is sampled from a
weighted mixture of the two models' next-token logits, each model seeing its
own system prompt.
"""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name_a = "mistralai/Mistral-7B-v0.1"
# You can replace this with a second different model or a finetuned variant.
# NOTE(review): both models are assumed to share one tokenizer/vocabulary —
# blending logits element-wise is only meaningful if the vocabularies match.
model_name_b = "mistralai/Mistral-7B-v0.1"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name_a)

print("Loading models...")
model_a = AutoModelForCausalLM.from_pretrained(
    model_name_a, device_map="auto", torch_dtype=torch.float16
)
model_b = AutoModelForCausalLM.from_pretrained(
    model_name_b, device_map="auto", torch_dtype=torch.float16
)
model_a.eval()
model_b.eval()


def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb,
                   max_length=50, temperature=0.7, top_k=50):
    """Generate text by sampling from a weighted blend of two models' logits.

    Each step, the text so far is prefixed with each model's system prompt,
    re-encoded, and run through both models.  The last-position logits are
    mixed as ``wa * logits_a + wb * logits_b``, top-k filtered, scaled by
    ``temperature``, and the next token is sampled from the softmax.

    Args:
        system_prompt_a: System prompt prepended for model A.
        system_prompt_b: System prompt prepended for model B.
        user_prompt: The user's text; generation continues from it.
        wa: Mixture weight for model A's logits.
        wb: Mixture weight for model B's logits.
        max_length: Maximum number of NEW tokens to sample.
        temperature: Softmax temperature (> 0).
        top_k: Number of highest-probability tokens kept before sampling.

    Returns:
        The user prompt followed by the generated continuation.
    """
    device = next(model_a.parameters()).device
    # Hoist loop-invariant prompt prefixes out of the sampling loop.
    prefix_a = system_prompt_a.strip() + "\n"
    prefix_b = system_prompt_b.strip() + "\n"

    # Accumulate sampled token ids and decode them TOGETHER each step:
    # decoding one token at a time loses SentencePiece word-boundary spaces
    # (the "▁" marker), which glues consecutive words into one.
    generated_ids = []
    generated_text = user_prompt

    for _ in range(max_length):
        input_ids_a = tokenizer(prefix_a + generated_text,
                                return_tensors="pt").input_ids.to(device)
        input_ids_b = tokenizer(prefix_b + generated_text,
                                return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            output_a = model_a(input_ids_a)
            output_b = model_b(input_ids_b)

        # Last-position (next-token) logits from each model, blended linearly.
        logits_a = output_a.logits[:, -1, :]
        logits_b = output_b.logits[:, -1, :]
        blended_logits = wa * logits_a + wb * logits_b

        # Top-k filtering: keep the k best logits, mask the rest to -inf.
        top_k_logits, top_k_indices = torch.topk(blended_logits, top_k)
        filtered_logits = torch.full_like(blended_logits, float("-inf"))
        filtered_logits.scatter_(1, top_k_indices, top_k_logits)

        # Temperature scaling, then sample one token from the distribution.
        probs = torch.softmax(filtered_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()

        if next_token == tokenizer.eos_token_id:
            break

        generated_ids.append(next_token)
        generated_text = user_prompt + tokenizer.decode(
            generated_ids, skip_special_tokens=True
        )

    return generated_text


with gr.Blocks() as demo:
    system_prompt_a = gr.Textbox(label="System Prompt A",
                                 value="You are a helpful assistant.")
    system_prompt_b = gr.Textbox(label="System Prompt B",
                                 value="You are a witty assistant.")
    user_prompt = gr.Textbox(label="User Prompt",
                             value="Tell me a story about a dragon.")
    weight_a = gr.Slider(minimum=0, maximum=1, value=0.5,
                         label="Weight Model A")
    weight_b = gr.Slider(minimum=0, maximum=1, value=0.5,
                         label="Weight Model B")
    output_text = gr.Textbox(label="Output", lines=10)
    btn = gr.Button("Generate")
    btn.click(
        blend_generate,
        inputs=[system_prompt_a, system_prompt_b, user_prompt,
                weight_a, weight_b],
        outputs=output_text,
    )

# Guard the server start so importing this module doesn't launch the UI.
if __name__ == "__main__":
    demo.launch()