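# Two-model logit blending demo: at each decoding step, the next-token logits from
# model A (conditioned on system prompt A) and model B (conditioned on system
# prompt B) are combined as a weighted sum, then sampled with top-k filtering and
# temperature scaling. A small Gradio UI exposes the prompts and the blend weights.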
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name_a = "mistralai/Mistral-7B-v0.1"
model_name_b = "mistralai/Mistral-7B-v0.1"  # replace with a second model or a fine-tuned variant

# A single tokenizer is shared by both models; this only works if they use the same vocabulary.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name_a)

print("Loading models...")
model_a = AutoModelForCausalLM.from_pretrained(model_name_a, device_map="auto", torch_dtype=torch.float16)
model_b = AutoModelForCausalLM.from_pretrained(model_name_b, device_map="auto", torch_dtype=torch.float16)
model_a.eval()
model_b.eval()
def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_length=50, temperature=0.7, top_k=50):
    """Generate text token by token, blending next-token logits from both models."""
    device = next(model_a.parameters()).device
    generated_text = user_prompt
    for _ in range(max_length):
        # Each model sees its own system prompt followed by the text generated so far.
        prompt_a = system_prompt_a.strip() + "\n" + generated_text
        prompt_b = system_prompt_b.strip() + "\n" + generated_text
        input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
        input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            output_a = model_a(input_ids_a)
            output_b = model_b(input_ids_b)
        # Take the logits for the next position and combine them as a weighted sum.
        logits_a = output_a.logits[:, -1, :]
        logits_b = output_b.logits[:, -1, :]
        blended_logits = wa * logits_a + wb * logits_b
        # Top-k filtering: keep the k most likely tokens, mask everything else to -inf.
        top_k_logits, top_k_indices = torch.topk(blended_logits, top_k)
        filtered_logits = torch.full_like(blended_logits, float("-inf"))
        filtered_logits.scatter_(1, top_k_indices, top_k_logits)
        # Temperature scaling, then sample the next token from the resulting distribution.
        scaled_logits = filtered_logits / temperature
        probs = torch.softmax(scaled_logits, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        if next_token == tokenizer.eos_token_id:
            break
        next_token_str = tokenizer.decode([next_token])
        generated_text += next_token_str
    return generated_text
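# Illustrative usage (a sketch, not part of the UI): blend the two default personas 50/50.
# story = blend_generate(
#     "You are a helpful assistant.",
#     "You are a witty assistant.",
#     "Tell me a story about a dragon.",
#     wa=0.5,
#     wb=0.5,
# )
# print(story)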
with gr.Blocks() as demo:
    system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a helpful assistant.")
    system_prompt_b = gr.Textbox(label="System Prompt B", value="You are a witty assistant.")
    user_prompt = gr.Textbox(label="User Prompt", value="Tell me a story about a dragon.")
    weight_a = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model A")
    weight_b = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model B")
    output_text = gr.Textbox(label="Output", lines=10)
    btn = gr.Button("Generate")
    btn.click(blend_generate, inputs=[system_prompt_a, system_prompt_b, user_prompt, weight_a, weight_b], outputs=output_text)

demo.launch()