Amossofer committed
Commit b30d3bb · 1 Parent(s): 7a8e287
Files changed (1)
  1. app.py +30 -30
app.py CHANGED
@@ -9,57 +9,57 @@ MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
 
-def chat_blend(systemA, systemB, wA, wB, user_message, max_new_tokens=50, temperature=1.0, top_p=0.95):
-    # Build messages
+# gr.ChatInterface calls fn(message, history, *additional_inputs), so the
+# message and history must come first; history is unused here, each turn
+# is treated independently.
+def chat_blend_stream(user_message, history, systemA, systemB, wA, wB, max_new_tokens=50, temperature=1.0, top_p=0.95):
     promptA = systemA + "\nUser: " + user_message + "\nAssistant:"
     promptB = systemB + "\nUser: " + user_message + "\nAssistant:"
+
     input_ids_A = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
     input_ids_B = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
-    output_ids = input_ids_A  # share history; or keep separate
+    output_ids = input_ids_A.clone()
+
+    response_text = ""
 
     for _ in range(max_new_tokens):
-        # forward last token
         logitsA = model(input_ids=input_ids_A).logits[:, -1, :]
         logitsB = model(input_ids=input_ids_B).logits[:, -1, :]
+
         probsA = F.softmax(logitsA / temperature, dim=-1)
         probsB = F.softmax(logitsB / temperature, dim=-1)
-        blended = wA * probsA + wB * probsB
-        # apply top_p
-        sorted_probs, sorted_idx = torch.sort(blended, descending=True)
+
+        # Negative weights can push entries below zero, which would make
+        # torch.multinomial fail; clamp to zero before renormalizing.
+        blended_probs = (wA * probsA + wB * probsB).clamp_min(0)
+        sorted_probs, sorted_idx = torch.sort(blended_probs, descending=True)
         cum = torch.cumsum(sorted_probs, dim=-1)
-        mask = cum > top_p
-        sorted_probs[mask] = 0
-        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
-        # repeat mapping back
-        new_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, num_samples=1))
-        output_ids = torch.cat([output_ids, new_id], dim=-1)
-        # append to each history
-        input_ids_A = torch.cat([input_ids_A, new_id], dim=-1)
-        input_ids_B = torch.cat([input_ids_B, new_id], dim=-1)
-        # stop on EOS
+        # Shift the cutoff so the top token always survives, even when its
+        # probability alone exceeds top_p.
+        sorted_probs[cum - sorted_probs > top_p] = 0
+        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
+
+        new_id = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
+        # new_id has shape (1,); view(1, 1) matches the (batch, seq) layout.
+        output_ids = torch.cat([output_ids, new_id.view(1, 1)], dim=1)
+        input_ids_A = torch.cat([input_ids_A, new_id.view(1, 1)], dim=1)
+        input_ids_B = torch.cat([input_ids_B, new_id.view(1, 1)], dim=1)
+
+        token_text = tokenizer.decode(new_id, skip_special_tokens=True)
+        response_text += token_text
+        yield response_text
+
         if new_id.item() == tokenizer.eos_token_id:
             break
 
-    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-    # strip system + user prompts
-    return decoded.split("Assistant:")[-1].strip()
-
-iface = gr.Interface(
-    fn=chat_blend,
-    inputs=[
+demo = gr.ChatInterface(
+    fn=chat_blend_stream,
+    additional_inputs=[
         gr.Textbox(label="System Prompt A", value="You are assistant A."),
         gr.Textbox(label="System Prompt B", value="You are assistant B."),
         gr.Slider(label="wA", minimum=-2.0, maximum=2.0, step=0.1, value=1.0),
         gr.Slider(label="wB", minimum=-2.0, maximum=2.0, step=0.1, value=1.0),
-        gr.Textbox(label="User message"),
         gr.Slider(label="Max new tokens", minimum=1, maximum=200, step=1, value=50),
         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
-        gr.Slider(label="Topp", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
+        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
     ],
-    outputs="text",
-    title="Blended‑LLM Chat (TinyLlama)",
-    description="Uses two system prompts and blends their token distributions using wA*p1 + wB*p2."
+    title="Blended TinyLlama",
+    description="Two system prompts, blended token distributions with positive or negative weights."
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    demo.launch()
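
One caveat on prompt formatting: TinyLlama-1.1B-Chat-v1.0 ships with its own chat template, so the hand-rolled "\nUser: ... \nAssistant:" framing above is off-template and may degrade responses. A minimal sketch of building each prompt with the tokenizer's template instead; the build_prompt helper is an illustration, not part of this commit:

# Hypothetical helper (not in this commit): format one system + user turn
# with the chat template bundled in the TinyLlama tokenizer.
def build_prompt(system_prompt, user_message):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

# promptA = build_prompt(systemA, user_message)
# promptB = build_prompt(systemB, user_message)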
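
The generation loop also re-runs a full forward pass over both growing sequences at every step and keeps autograd state, so per-token cost grows with sequence length. Below is a sketch of the same blended sampling using torch.no_grad() and the cached-generation path in transformers (use_cache / past_key_values); it assumes the model and tokenizer objects from app.py, at least one positive weight, and omits top-p for brevity:

import torch
import torch.nn.functional as F

@torch.no_grad()
def step(input_ids, past=None):
    # First call runs the full prompt; later calls feed only the newest
    # token, with cached keys/values covering the rest of the sequence.
    out = model(input_ids=input_ids, past_key_values=past, use_cache=True)
    return out.logits[:, -1, :], out.past_key_values

@torch.no_grad()
def generate_blended_cached(promptA, promptB, wA, wB, max_new_tokens=50, temperature=1.0):
    idsA = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
    idsB = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
    logitsA, pastA = step(idsA)
    logitsB, pastB = step(idsB)
    generated = []
    for _ in range(max_new_tokens):
        # Blend the two temperature-scaled distributions, clamp negatives
        # introduced by negative weights, and renormalize before sampling.
        blended = (wA * F.softmax(logitsA / temperature, dim=-1)
                   + wB * F.softmax(logitsB / temperature, dim=-1)).clamp_min(0)
        blended /= blended.sum(dim=-1, keepdim=True)
        new_id = torch.multinomial(blended[0], 1).view(1, 1)
        generated.append(new_id.item())
        if new_id.item() == tokenizer.eos_token_id:
            break
        logitsA, pastA = step(new_id, pastA)
        logitsB, pastB = step(new_id, pastB)
    return tokenizer.decode(generated, skip_special_tokens=True)

With the cache, each step feeds only the newest token, so per-token cost stays roughly constant instead of growing with the sequence.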