Amossofer committed
Commit 3483699
1 Parent(s): bc5c17c
Files changed (1)
  1. app.py +13 -11
app.py CHANGED
@@ -3,9 +3,9 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-MODEL_ID = "tiiuae/falcon-rw-1b"
+MODEL_ID = "tiiuae/falcon-rw-1b"  # small model for local use
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cpu")
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cpu")  # or "cuda" if available
 
 def generate_stream(sysA, sysB, wa, wb, user_input, max_new_tokens=50, temperature=1.0, top_p=0.95):
     promptA = f"<|system|>{sysA}\n<|user|>{user_input}<|assistant|>"
@@ -17,15 +17,17 @@ def generate_stream(sysA, sysB, wa, wb, user_input, max_new_tokens=50, temperatu
     outA = idsA.clone()
     outB = idsB.clone()
     response = ""
-    yield response  # initial empty chunk
+    yield response  # send initial blank to start stream
 
     for _ in range(max_new_tokens):
         with torch.no_grad():
             logitsA = model(input_ids=outA).logits[:, -1, :]
             logitsB = model(input_ids=outB).logits[:, -1, :]
 
+        # Weighted average of logits
         logits = wa * logitsA + wb * logitsB
         logits = logits / (temperature if temperature > 0 else 1.0)
+
         probs = F.softmax(logits, dim=-1)
 
         sorted_probs, sorted_idx = torch.sort(probs, descending=True)
@@ -44,22 +46,22 @@ def generate_stream(sysA, sysB, wa, wb, user_input, max_new_tokens=50, temperatu
         if token.item() == tokenizer.eos_token_id:
             break
 
-with gr.ChatInterface(
+# ✅ Define the demo interface correctly
+demo = gr.ChatInterface(
     fn=generate_stream,
     inputs=[
-        gr.Textbox(label="System Prompt A", value="You are assistant A"),
-        gr.Textbox(label="System Prompt B", value="You are assistant B"),
+        gr.Textbox(label="System Prompt A", value="You are assistant A."),
+        gr.Textbox(label="System Prompt B", value="You are assistant B."),
         gr.Slider(label="Weight wA", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
         gr.Slider(label="Weight wB", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
-        gr.Textbox(label="User Message", placeholder="Enter your message here..."),
+        gr.Textbox(label="User Message", placeholder="Enter your message..."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=200, step=1, value=50),
         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
         gr.Slider(label="Top‑p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
     ],
-    title="Blended Two-System Streaming Chat",
-    description="Stream replies by blending logits from two system-prompts using weights wA and wB.",
-):
-    pass
+    title="Two-System Weighted Blending Chat",
+    description="Combines two system prompts using weighted logit blending: response = wA⋅modelA + wB⋅modelB.",
+)
 
 if __name__ == "__main__":
     demo.launch()
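A note on the blend this diff touches: each step mixes the two next-token logit vectors as `wa * logitsA + wb * logitsB` before temperature scaling and softmax, so the description's "response = wA⋅modelA + wB⋅modelB" is shorthand for a per-step blend in logit space, not an average of the two probability distributions. A toy illustration (vocabulary and values invented for the example):

```python
import torch
import torch.nn.functional as F

# Toy 4-token vocabulary; values are invented for illustration.
logitsA = torch.tensor([2.0, 0.0, 0.0, 0.0])  # prompt A strongly favors token 0
logitsB = torch.tensor([0.0, 2.0, 0.0, 0.0])  # prompt B strongly favors token 1
wa, wb, temperature = 1.0, 1.0, 1.0

# Same order of operations as app.py: weight, temperature, softmax.
probs = F.softmax((wa * logitsA + wb * logitsB) / temperature, dim=-1)
print(probs)  # tokens 0 and 1 come out equally likely
```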
 
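The hunks jump from `torch.sort(...)` at file line 31 to the EOS check at line 44, so the sampling step itself is not shown. Below is a minimal sketch of standard top-p (nucleus) sampling that is consistent with the visible `sorted_probs`, `sorted_idx`, `token`, and `top_p` names; the function name and body are assumptions about the elided code, not a quote of it:

```python
import torch

def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Hypothetical stand-in for app.py's elided sampling block:
    keep the smallest set of top tokens whose cumulative probability
    exceeds top_p, renormalize, and sample one token id."""
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens outside the nucleus; the top token always survives.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample in sorted space, then map back to vocabulary ids.
    pick = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, pick)
```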
 
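Since `generate_stream` is a plain generator, the blend can also be exercised without the Gradio UI. A usage sketch, assuming (as the initial `yield response` suggests) that each yielded chunk is the full response so far; the prompts and weights here are arbitrary:

```python
final = ""
for partial in generate_stream(
    sysA="You are a formal assistant.",
    sysB="You are a casual assistant.",
    wa=1.5,
    wb=-0.5,  # a negative weight steers away from prompt B's behavior
    user_input="Say hello.",
    max_new_tokens=30,
):
    final = partial  # keep the latest (longest) chunk
print(final)
```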