Amossofer committed on
Commit f45c0a2 · 1 Parent(s): 142eb42
Files changed (2)
  1. app.py +22 -17
  2. requirements.txt +3 -4
app.py CHANGED
@@ -11,7 +11,7 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-def generate_stream(sysA, sysB, wA, wB, user_message, max_new_tokens=100, temperature=1.0, top_p=0.9):
+def blend_generate(sysA, sysB, wA, wB, user_message, max_new_tokens, temperature, top_p):
     promptA = f"<|system|>{sysA}\n<|user|>{user_message}\n<|assistant|>"
     promptB = f"<|system|>{sysB}\n<|user|>{user_message}\n<|assistant|>"
 
@@ -20,7 +20,6 @@ def generate_stream(sysA, sysB, wA, wB, user_message, max_new_tokens=100, temperature=1.0, top_p=0.9):
 
     outA, outB = idsA.clone(), idsB.clone()
     response = ""
-    yield response  # start stream
 
     for _ in range(max_new_tokens):
         with torch.no_grad():
@@ -28,7 +27,7 @@ def generate_stream(sysA, sysB, wA, wB, user_message, max_new_tokens=100, temperature=1.0, top_p=0.9):
            logitsB = model(input_ids=outB).logits[:, -1, :]
 
        blended = wA * logitsA + wB * logitsB
-       blended = blended / (temperature if temperature > 0 else 1.0)
+       blended = blended / temperature
 
        probs = F.softmax(blended, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
@@ -47,20 +46,26 @@ def generate_stream(sysA, sysB, wA, wB, user_message, max_new_tokens=100, temperature=1.0, top_p=0.9):
        if token.item() == tokenizer.eos_token_id:
            break
 
-demo = gr.ChatInterface(
-    fn=generate_stream,
-    inputs=[
-        gr.Textbox(label="System Prompt A", value="You are assistant A."),
-        gr.Textbox(label="System Prompt B", value="You are assistant B."),
-        gr.Slider(label="Weight wA", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
-        gr.Slider(label="Weight wB", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
-        gr.Textbox(label="User Message"),
-        gr.Slider(label="Max New Tokens", minimum=1, maximum=200, step=1, value=100),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
-        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9),
-    ],
-    title="Streaming Blended TinyLlama Chat"
-)
+with gr.Blocks() as demo:
+    gr.Markdown("## Blended Prompt Chat (TinyLlama)")
+    sysA = gr.Textbox(label="System Prompt A", value="You are assistant A.")
+    sysB = gr.Textbox(label="System Prompt B", value="You are assistant B.")
+    wA = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight A")
+    wB = gr.Slider(-5, 5, value=1.0, step=0.1, label="Weight B")
+    user_msg = gr.Textbox(label="User Message")
+    temp = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
+    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+    max_tokens = gr.Slider(1, 200, value=100, step=1, label="Max New Tokens")
+    output = gr.Textbox(label="Response")
+
+    btn = gr.Button("Generate")
+    btn.click(
+        blend_generate,
+        [sysA, sysB, wA, wB, user_msg, max_tokens, temp, top_p],
+        output,
+        show_progress=True,
+        stream=True
+    )
 
 if __name__ == "__main__":
     demo.launch()
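For context on the lines touched by the third hunk: the function blends the next-token logits produced from the two system prompts, rescales by temperature, then applies top-p (nucleus) filtering before drawing a token. The sketch below is an illustration only, not the contents of app.py; it uses random tensors in place of the real model outputs, and the vocabulary size, weights, and variable names are assumptions.

import torch
import torch.nn.functional as F

# Illustrative stand-ins (assumptions, not values from the commit): two random
# logit rows play the role of model(input_ids=outA/outB).logits[:, -1, :].
vocab_size = 32000
logitsA = torch.randn(1, vocab_size)
logitsB = torch.randn(1, vocab_size)
wA, wB, temperature, top_p = 1.0, 1.0, 1.0, 0.9

# Blend the two prompts' next-token logits, then apply temperature. The commit
# drops the "temperature if temperature > 0 else 1.0" guard and relies on the
# UI slider's 0.1 minimum to keep the divisor non-zero.
blended = (wA * logitsA + wB * logitsB) / temperature

# Top-p (nucleus) sampling: keep the smallest set of high-probability tokens
# whose cumulative mass reaches top_p, renormalize, and sample one token.
probs = F.softmax(blended, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
sorted_probs[cumulative - sorted_probs > top_p] = 0.0
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
pick = torch.multinomial(sorted_probs, num_samples=1)
token = sorted_idx.gather(-1, pick)
print(token.item())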
requirements.txt CHANGED
@@ -1,4 +1,3 @@
-transformers>=4.31
-torch
-gradio
-accelerate
+gradio>=3.50.0
+transformers>=4.40.0
+torch>=2.2.0
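After installing with the pinned lower bounds above, a quick local check (a convenience sketch, not part of the repo) can confirm the environment matches:

from importlib.metadata import version

# Lower bounds copied from the updated requirements.txt.
for pkg, minimum in [("gradio", "3.50.0"), ("transformers", "4.40.0"), ("torch", "2.2.0")]:
    print(f"{pkg} {version(pkg)} (requires >= {minimum})")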