Amossofer committed on
Commit
38b5252
·
1 Parent(s): 1782685
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -17,37 +17,41 @@ model_b = AutoModelForCausalLM.from_pretrained(model_name_b).to(device)
17
model_a.eval()
model_b.eval()


def blend_generate(prompt, wa, wb):
    """Append one sampled token to *prompt*, drawn from a blend of both models.

    Both models score the same prompt; their last-position logits are
    mixed as ``wa * logits_a + wb * logits_b`` before softmax and sampling.

    Args:
        prompt: Text fed identically to model A and model B.
        wa: Weight applied to model A's logits.
        wb: Weight applied to model B's logits.

    Returns:
        The prompt with a single sampled token appended.
    """
    encoded = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        last_a = model_a(encoded).logits[:, -1, :]
        last_b = model_b(encoded).logits[:, -1, :]

    # Blend the raw (pre-softmax) logits, then normalize to a distribution.
    mixture = torch.softmax(wa * last_a + wb * last_b, dim=-1)

    # Draw one token id from the blended distribution.
    sampled_id = torch.multinomial(mixture, 1).item()
    return prompt + tokenizer.decode([sampled_id])
42
 
43
# Gradio front-end: a single prompt plus one blend-weight slider per model.
with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt", lines=2)
    slider_a = gr.Slider(0, 1, value=0.5, label="Weight model A")
    slider_b = gr.Slider(0, 1, value=0.5, label="Weight model B")
    result_box = gr.Textbox(label="Output")

    generate_btn = gr.Button("Generate")
    generate_btn.click(
        blend_generate,
        inputs=[prompt_box, slider_a, slider_b],
        outputs=result_box,
    )

demo.launch()
 
17
model_a.eval()
model_b.eval()


def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb,
                   max_new_tokens=1):
    """Sample a continuation from a weighted blend of two models' logits.

    Each model sees its own system prompt prepended to the shared user
    prompt. At every step the last-position logits are combined as
    ``wa * logits_a + wb * logits_b``, softmaxed into a distribution, and
    one token is sampled from it.

    Args:
        system_prompt_a: System prompt prepended to the user prompt for model A.
        system_prompt_b: System prompt prepended to the user prompt for model B.
        user_prompt: Prompt text shared by both models.
        wa: Blend weight for model A's logits.
        wb: Blend weight for model B's logits.
        max_new_tokens: Number of tokens to sample. Defaults to 1, which
            preserves the original single-token behavior.

    Returns:
        ``user_prompt`` with the sampled continuation appended (the system
        prompts are intentionally not echoed back).
    """
    ids_a = tokenizer(system_prompt_a + user_prompt,
                      return_tensors="pt").input_ids.to(device)
    ids_b = tokenizer(system_prompt_b + user_prompt,
                      return_tensors="pt").input_ids.to(device)

    sampled_ids = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits_a = model_a(ids_a).logits[:, -1, :]
            logits_b = model_b(ids_b).logits[:, -1, :]

            # Weighted sum of raw logits (before softmax), then normalize.
            probs = torch.softmax(wa * logits_a + wb * logits_b, dim=-1)

            token = torch.multinomial(probs, 1)  # shape (1, 1)
            sampled_ids.append(token.item())

            # Feed the sampled token back into BOTH contexts so subsequent
            # steps condition on everything generated so far.
            ids_a = torch.cat([ids_a, token], dim=-1)
            ids_b = torch.cat([ids_b, token], dim=-1)

    return user_prompt + tokenizer.decode(sampled_ids)
45
+
46
# Gradio front-end: per-model system prompts plus blend-weight sliders.
with gr.Blocks() as demo:
    sys_a_box = gr.Textbox(label="System Prompt A", value="You are a helpful assistant. ")
    sys_b_box = gr.Textbox(label="System Prompt B", value="You are a sarcastic assistant. ")
    user_box = gr.Textbox(label="User Prompt")
    slider_a = gr.Slider(0, 1, value=0.5, label="Weight Model A")
    slider_b = gr.Slider(0, 1, value=0.5, label="Weight Model B")
    result_box = gr.Textbox(label="Output")

    generate_btn = gr.Button("Generate")
    generate_btn.click(
        blend_generate,
        inputs=[sys_a_box, sys_b_box, user_box, slider_a, slider_b],
        outputs=result_box,
    )

demo.launch()