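"""Blend two causal language models at the logit level.

Each model conditions on its own system prompt plus the shared running text;
their next-token logits are mixed with user-set weights, and a single token
is sampled from the blended distribution. A small Gradio UI exposes the
prompts and the weights.
"""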
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

model_name_a = "mistralai/Mistral-7B-v0.1"
model_name_b = "mistralai/Mistral-7B-v0.1"  # swap in a different model or fine-tuned variant for a meaningful blend

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name_a)

print("Loading models...")
model_a = AutoModelForCausalLM.from_pretrained(model_name_a, device_map="auto", torch_dtype=torch.float16)
model_b = AutoModelForCausalLM.from_pretrained(model_name_b, device_map="auto", torch_dtype=torch.float16)

model_a.eval()
model_b.eval()

def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb,
                   max_new_tokens=50, temperature=0.7, top_k=50):
    device = next(model_a.parameters()).device

    # Track generated ids and re-decode them together with the user prompt
    # each step; decoding tokens one at a time drops the SentencePiece
    # whitespace markers Mistral's tokenizer uses.
    user_ids = tokenizer(user_prompt, add_special_tokens=False).input_ids
    generated_ids = []
    generated_text = user_prompt

    for _ in range(max_new_tokens):
        # Each model conditions on its own system prompt plus the shared
        # running text.
        prompt_a = system_prompt_a.strip() + "\n" + generated_text
        prompt_b = system_prompt_b.strip() + "\n" + generated_text

        input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
        input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            output_a = model_a(input_ids_a)
            output_b = model_b(input_ids_b)

        # Next-token logits from each model's final position.
        logits_a = output_a.logits[:, -1, :]
        logits_b = output_b.logits[:, -1, :]

        # Weighted blend of the two predictions in logit space.
        blended_logits = wa * logits_a + wb * logits_b

        # Top-k filtering: keep the k highest blended logits, mask the rest.
        top_k_logits, top_k_indices = torch.topk(blended_logits, top_k)
        filtered_logits = torch.full_like(blended_logits, float("-inf"))
        filtered_logits.scatter_(1, top_k_indices, top_k_logits)

        # Temperature scaling, then sample one token.
        probs = torch.softmax(filtered_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()

        if next_token == tokenizer.eos_token_id:
            break

        generated_ids.append(next_token)
        generated_text = tokenizer.decode(user_ids + generated_ids)

    return generated_text
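
# Quick smoke test without the UI (prompts and weights here are illustrative):
# print(blend_generate("You are terse.", "You are verbose.",
#                      "Once upon a time", 0.5, 0.5, max_new_tokens=20))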


# Minimal Gradio UI: two system prompts, one user prompt, and blend weights.
with gr.Blocks() as demo:
    system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a helpful assistant.")
    system_prompt_b = gr.Textbox(label="System Prompt B", value="You are a witty assistant.")
    user_prompt = gr.Textbox(label="User Prompt", value="Tell me a story about a dragon.")
    weight_a = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model A")
    weight_b = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model B")
    output_text = gr.Textbox(label="Output", lines=10)

    btn = gr.Button("Generate")
    btn.click(blend_generate, inputs=[system_prompt_a, system_prompt_b, user_prompt, weight_a, weight_b], outputs=output_text)

demo.launch()
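
# Run with `python blend_demo.py` (the filename here is assumed); Gradio
# serves the UI at http://127.0.0.1:7860 by default, and passing
# share=True to launch() would expose a temporary public link.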