Amossofer committed
Commit
bc5c17c
·
1 Parent(s): 3ff7c04
Files changed (1)
  1. app.py +57 -62
app.py CHANGED
@@ -1,70 +1,65 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForCausalLM

- # Load tiny model (CPU-friendly)
  MODEL_ID = "tiiuae/falcon-rw-1b"
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cpu")

- 
- def get_logits(prompt, system_msg):
-     """Run the model and return logits for the next token."""
-     input_text = f"<|system|>{system_msg}\n<|user|>{prompt}<|assistant|>"
-     inputs = tokenizer(input_text, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model(**inputs)
-     logits = outputs.logits[:, -1, :]  # Only final token logits
-     return logits
- 
- 
- def blended_generate(prompt, sys1, sys2, wa, wb, temperature=1.0):
-     # Get logits from both system prompts
-     logits1 = get_logits(prompt, sys1)
-     logits2 = get_logits(prompt, sys2)
- 
-     # Weighted sum
-     blended_logits = wa * logits1 + wb * logits2
- 
-     # Apply temperature
-     blended_logits = blended_logits / temperature
- 
-     # Convert to probabilities
-     probs = F.softmax(blended_logits, dim=-1)
- 
-     # Sample one token from the distribution
-     token_id = torch.multinomial(probs, num_samples=1)
-     next_token = tokenizer.decode(token_id[0])
- 
-     return next_token.strip()
- 
- 
- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("## 🔀 Blended System Prompts using Falcon-RW-1B")
- 
-     with gr.Row():
-         prompt = gr.Textbox(label="User Prompt", value="Tell me a joke about computers.")
- 
-     with gr.Row():
-         sys1 = gr.Textbox(label="System Prompt A", value="You are a polite assistant.")
-         sys2 = gr.Textbox(label="System Prompt B", value="You are a sarcastic assistant.")
- 
-     with gr.Row():
-         wa = gr.Slider(-10, 10, value=1.0, step=0.1, label="Weight A")
-         wb = gr.Slider(-10, 10, value=1.0, step=0.1, label="Weight B")
- 
-     with gr.Row():
-         temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
- 
-     output = gr.Textbox(label="Next Token")
- 
-     generate_btn = gr.Button("Generate Next Token")
-     generate_btn.click(
-         fn=blended_generate,
-         inputs=[prompt, sys1, sys2, wa, wb, temperature],
-         outputs=output,
-     )
- 
- demo.launch()
+ def generate_stream(sysA, sysB, wa, wb, user_input, max_new_tokens=50, temperature=1.0, top_p=0.95):
+     promptA = f"<|system|>{sysA}\n<|user|>{user_input}<|assistant|>"
+     promptB = f"<|system|>{sysB}\n<|user|>{user_input}<|assistant|>"
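+     # Note: Falcon-RW-1B is a plain causal LM; these <|system|>/<|user|> tags
+     # are an ad-hoc text convention, not trained chat special tokens.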
+ 
+     idsA = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
+     idsB = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
+ 
+     outA = idsA.clone()
+     outB = idsB.clone()
+     response = ""
+     yield response  # initial empty chunk
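+ 
+     # Sampling loop without a KV cache: each step re-runs both full prefixes
+     # through the model (one per system prompt), then appends the same sampled
+     # token to both contexts so they differ only in their system prompt.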
+     for _ in range(max_new_tokens):
+         with torch.no_grad():
+             logitsA = model(input_ids=outA).logits[:, -1, :]
+             logitsB = model(input_ids=outB).logits[:, -1, :]
+ 
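+         # softmax(wa*logitsA + wb*logitsB) is proportional to p_A^wa * p_B^wb,
+         # a product-of-experts blend of the two next-token distributions;
+         # negative weights steer the sample away from that prompt's behavior.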
+         logits = wa * logitsA + wb * logitsB
+         logits = logits / (temperature if temperature > 0 else 1.0)
+         probs = F.softmax(logits, dim=-1)
+ 
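+         # Top-p (nucleus) filtering of the blended distribution.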
+         sorted_probs, sorted_idx = torch.sort(probs, descending=True)
+         cumulative = torch.cumsum(sorted_probs, dim=-1)
+         # Mask tokens whose cumulative mass *before* them already exceeds
+         # top_p, so the top token always survives (zeroing on `cumulative >
+         # top_p` alone empties the distribution when the top token's
+         # probability exceeds top_p, and the renormalization divides by zero).
+         sorted_probs[cumulative - sorted_probs > top_p] = 0
+         sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
+ 
+         sample = torch.multinomial(sorted_probs, num_samples=1)  # index into the sorted order
+         token = sorted_idx.gather(-1, sample).squeeze()
+         outA = torch.cat([outA, token.view(1, 1)], dim=1)
+         outB = torch.cat([outB, token.view(1, 1)], dim=1)
+ 
+         token_str = tokenizer.decode(token)
+         response += token_str
+         yield response
+ 
+         if token.item() == tokenizer.eos_token_id:
+             break
+ 
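+ # A generator fn makes Gradio stream: each yielded string replaces the
+ # output textbox contents as new tokens arrive.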
+ demo = gr.Interface(
+     fn=generate_stream,
+     inputs=[
+         gr.Textbox(label="System Prompt A", value="You are assistant A"),
+         gr.Textbox(label="System Prompt B", value="You are assistant B"),
+         gr.Slider(label="Weight wA", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
+         gr.Slider(label="Weight wB", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
+         gr.Textbox(label="User Message", placeholder="Enter your message here..."),
+         gr.Slider(label="Max new tokens", minimum=1, maximum=200, step=1, value=50),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
+         gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
+     ],
+     outputs=gr.Textbox(label="Response"),
+     title="Blended Two-System Streaming Chat",
+     description="Stream replies by blending logits from two system prompts using weights wA and wB.",
+ )
+ 
+ if __name__ == "__main__":
+     demo.launch()