import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both entries default to the same base model; point them at two different
# checkpoints to blend genuinely different models that share a vocabulary.
model_name_a = "mistralai/Mistral-7B-v0.1"
model_name_b = "mistralai/Mistral-7B-v0.1"

print("Loading tokenizer...")
# One tokenizer is shared by both models; the logit blending below assumes
# they use the same vocabulary.
tokenizer = AutoTokenizer.from_pretrained(model_name_a)

print("Loading models...")
# Half precision keeps memory manageable (roughly 14 GB per 7B model);
# device_map="auto" lets accelerate place the weights across available devices.
model_a = AutoModelForCausalLM.from_pretrained(model_name_a, device_map="auto", torch_dtype=torch.float16)
model_b = AutoModelForCausalLM.from_pretrained(model_name_b, device_map="auto", torch_dtype=torch.float16)

model_a.eval()
model_b.eval()
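
# Note: with identical model names, the same weights are loaded twice. If you
# only ever blend one model against itself (with different system prompts),
# reusing a single instance (model_b = model_a) would halve the memory cost.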


def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_length=50, temperature=0.7, top_k=50):
    # max_length counts newly generated tokens, not the total sequence length.
    device = next(model_a.parameters()).device
    generated_text = user_prompt

    # Sampling loop: each model sees its own system prompt followed by the
    # shared, growing text, and the two next-token logit vectors are blended
    # with the user-supplied weights wa and wb.
    for _ in range(max_length):
        prompt_a = system_prompt_a.strip() + "\n" + generated_text
        prompt_b = system_prompt_b.strip() + "\n" + generated_text

        input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
        input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            output_a = model_a(input_ids_a)
            output_b = model_b(input_ids_b)

        # Logits at the last position are each model's prediction for the next token.
        logits_a = output_a.logits[:, -1, :]
        logits_b = output_b.logits[:, -1, :]

        # Weighted combination in logit space.
        blended_logits = wa * logits_a + wb * logits_b

        # Top-k filtering: keep the k largest logits, mask everything else to -inf.
        top_k_logits, top_k_indices = torch.topk(blended_logits, top_k)
        filtered_logits = torch.full_like(blended_logits, float('-inf'))
        filtered_logits.scatter_(1, top_k_indices, top_k_logits)

        # Temperature scaling, then sample the next token from the blended distribution.
        scaled_logits = filtered_logits / temperature
        probs = torch.softmax(scaled_logits, dim=-1)
        next_token = torch.multinomial(probs, 1).item()

        if next_token == tokenizer.eos_token_id:
            break

        # Decoding tokens one at a time can drop the leading space of
        # SentencePiece pieces; acceptable for a demo, but decode the full id
        # sequence if exact whitespace matters.
        next_token_str = tokenizer.decode([next_token])
        generated_text += next_token_str

    return generated_text
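
# Illustrative usage (not part of the UI flow): the blending loop can be
# sanity-checked from a Python shell before launching the demo, e.g.
#
#   text = blend_generate("You are a helpful assistant.",
#                         "You are a witty assistant.",
#                         "Tell me a story about a dragon.",
#                         wa=0.5, wb=0.5, max_length=20)
#   print(text)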


# Minimal Gradio UI: two system prompts, one shared user prompt, and a
# blending weight per model.
with gr.Blocks() as demo:
    system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a helpful assistant.")
    system_prompt_b = gr.Textbox(label="System Prompt B", value="You are a witty assistant.")
    user_prompt = gr.Textbox(label="User Prompt", value="Tell me a story about a dragon.")
    weight_a = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model A")
    weight_b = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model B")
    output_text = gr.Textbox(label="Output", lines=10)

    btn = gr.Button("Generate")
    btn.click(blend_generate, inputs=[system_prompt_a, system_prompt_b, user_prompt, weight_a, weight_b], outputs=output_text)

demo.launch()
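# If the demo needs to be reachable beyond localhost, Gradio also supports
# e.g. demo.launch(share=True) or demo.launch(server_name="0.0.0.0").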