Amossofer committed on
Commit
592da02
·
1 Parent(s): 3483699
Files changed (1) hide show
  1. app.py +34 -62
app.py CHANGED
@@ -1,67 +1,39 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn.functional as F
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
- MODEL_ID = "tiiuae/falcon-rw-1b" # small model for local use
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
8
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cpu") # or "cuda" if available
9
-
10
- def generate_stream(sysA, sysB, wa, wb, user_input, max_new_tokens=50, temperature=1.0, top_p=0.95):
11
- promptA = f"<|system|>{sysA}\n<|user|>{user_input}<|assistant|>"
12
- promptB = f"<|system|>{sysB}\n<|user|>{user_input}<|assistant|>"
13
-
14
- idsA = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
15
- idsB = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
16
-
17
- outA = idsA.clone()
18
- outB = idsB.clone()
19
- response = ""
20
- yield response # send initial blank to start stream
21
-
22
- for _ in range(max_new_tokens):
23
- with torch.no_grad():
24
- logitsA = model(input_ids=outA).logits[:, -1, :]
25
- logitsB = model(input_ids=outB).logits[:, -1, :]
26
-
27
- # Weighted average of logits
28
- logits = wa * logitsA + wb * logitsB
29
- logits = logits / (temperature if temperature > 0 else 1.0)
30
-
31
- probs = F.softmax(logits, dim=-1)
32
-
33
- sorted_probs, sorted_idx = torch.sort(probs, descending=True)
34
- cumulative = torch.cumsum(sorted_probs, dim=-1)
35
- sorted_probs[cumulative > top_p] = 0
36
- sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
37
-
38
- token = sorted_idx[:, torch.multinomial(sorted_probs, 1)].squeeze()
39
- outA = torch.cat([outA, token.unsqueeze(0).unsqueeze(0)], dim=1)
40
- outB = torch.cat([outB, token.unsqueeze(0).unsqueeze(0)], dim=1)
41
-
42
- token_str = tokenizer.decode(token)
43
- response += token_str
44
- yield response
45
-
46
- if token.item() == tokenizer.eos_token_id:
47
- break
48
-
49
- # ✅ Define the demo interface correctly
50
- demo = gr.ChatInterface(
51
- fn=generate_stream,
52
- inputs=[
53
- gr.Textbox(label="System Prompt A", value="You are assistant A."),
54
- gr.Textbox(label="System Prompt B", value="You are assistant B."),
55
- gr.Slider(label="Weight wA", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
56
- gr.Slider(label="Weight wB", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
57
- gr.Textbox(label="User Message", placeholder="Enter your message..."),
58
- gr.Slider(label="Max new tokens", minimum=1, maximum=200, step=1, value=50),
59
- gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
60
- gr.Slider(label="Top‑p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
61
- ],
62
- title="Two-System Weighted Blending Chat",
63
- description="Combines two system prompts using weighted logit blending: response = wA⋅modelA + wB⋅modelB.",
64
- )
65
 
66
  if __name__ == "__main__":
67
  demo.launch()
 
1
  import gradio as gr
 
 
 
2
 
3
+ def generate(sysA, sysB, wa, wb, user_input):
4
+ # Example blending logic — replace with your actual model call
5
+ response = (
6
+ f"System Prompt A: {sysA}\n"
7
+ f"System Prompt B: {sysB}\n"
8
+ f"Weight A: {wa}\n"
9
+ f"Weight B: {wb}\n"
10
+ f"User message: {user_input}\n\n"
11
+ "=== Response ===\n"
12
+ f"Blended response based on weights."
13
+ )
14
+ return response
15
+
16
+ with gr.Blocks() as demo:
17
+ gr.Markdown("# Multi-System Prompt Chat Demo")
18
+
19
+ with gr.Row():
20
+ sysA = gr.Textbox(label="System Prompt A", value="You are assistant A.", lines=2)
21
+ sysB = gr.Textbox(label="System Prompt B", value="You are assistant B.", lines=2)
22
+
23
+ with gr.Row():
24
+ wa = gr.Slider(-5.0, 5.0, value=1.0, step=0.1, label="Weight wA")
25
+ wb = gr.Slider(-5.0, 5.0, value=1.0, step=0.1, label="Weight wB")
26
+
27
+ user_input = gr.Textbox(label="User Message", placeholder="Type your message here...")
28
+ output = gr.Textbox(label="Model Response", lines=10)
29
+
30
+ submit_btn = gr.Button("Send")
31
+
32
+ submit_btn.click(
33
+ fn=generate,
34
+ inputs=[sysA, sysB, wa, wb, user_input],
35
+ outputs=output
36
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  if __name__ == "__main__":
39
  demo.launch()