Amossofer committed on
Commit 3ff7c04 · 1 Parent(s): b30d3bb
Files changed (1)
  1. app.py +64 -59
app.py CHANGED
@@ -1,65 +1,70 @@
  import gradio as gr
  import torch
  import torch.nn.functional as F
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- # Model selection
- MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
-
- def chat_blend_stream(systemA, systemB, wA, wB, user_message, max_new_tokens=50, temperature=1.0, top_p=0.95):
-     promptA = systemA + "\nUser: " + user_message + "\nAssistant:"
-     promptB = systemB + "\nUser: " + user_message + "\nAssistant:"
-
-     input_ids_A = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
-     input_ids_B = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
-     output_ids = input_ids_A.clone()
-
-     response_text = ""
-
-     for _ in range(max_new_tokens):
-         logitsA = model(input_ids=input_ids_A).logits[:, -1, :]
-         logitsB = model(input_ids=input_ids_B).logits[:, -1, :]
-
-         probsA = F.softmax(logitsA / temperature, dim=-1)
-         probsB = F.softmax(logitsB / temperature, dim=-1)
-
-         blended_probs = wA * probsA + wB * probsB
-         sorted_probs, sorted_idx = torch.sort(blended_probs, descending=True)
-         cum = torch.cumsum(sorted_probs, dim=-1)
-         sorted_probs[cum > top_p] = 0
-         sorted_probs /= sorted_probs.sum()
-
-         new_id = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
-         output_ids = torch.cat([output_ids, new_id.unsqueeze(0).unsqueeze(0)], dim=1)
-         input_ids_A = torch.cat([input_ids_A, new_id.unsqueeze(0).unsqueeze(0)], dim=1)
-         input_ids_B = torch.cat([input_ids_B, new_id.unsqueeze(0).unsqueeze(0)], dim=1)
-
-         token_text = tokenizer.decode(new_id)
-         response_text += token_text
-         yield response_text
-
-         if new_id.item() == tokenizer.eos_token_id:
-             break
-
- demo = gr.ChatInterface(
-     fn=chat_blend_stream,
-     additional_inputs=[
-         gr.Textbox(label="System Prompt A", value="You are assistant A."),
-         gr.Textbox(label="System Prompt B", value="You are assistant B."),
-         gr.Slider(label="wA", minimum=-2.0, maximum=2.0, step=0.1, value=1.0),
-         gr.Slider(label="wB", minimum=-2.0, maximum=2.0, step=0.1, value=1.0),
-         gr.Slider(label="Max new tokens", minimum=1, maximum=200, step=1, value=50),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
-         gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
-     ],
-     title="Blended TinyLlama",
-     description="Two-system prompts, blended logits with negative/positive weights."
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import torch.nn.functional as F

+ # Load tiny model (CPU-friendly)
+ MODEL_ID = "tiiuae/falcon-rw-1b"
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cpu")
+
+
+ def get_logits(prompt, system_msg):
+     """Run the model and return logits for the next token."""
+     input_text = f"<|system|>{system_msg}\n<|user|>{prompt}<|assistant|>"
+     inputs = tokenizer(input_text, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits[:, -1, :]  # Only final token logits
+     return logits
+
+
+ def blended_generate(prompt, sys1, sys2, wa, wb, temperature=1.0):
+     # Get logits from both system prompts
+     logits1 = get_logits(prompt, sys1)
+     logits2 = get_logits(prompt, sys2)
+
+     # Weighted sum
+     blended_logits = wa * logits1 + wb * logits2
+
+     # Apply temperature
+     blended_logits = blended_logits / temperature
+
+     # Convert to probabilities
+     probs = F.softmax(blended_logits, dim=-1)
+
+     # Sample one token from the distribution
+     token_id = torch.multinomial(probs, num_samples=1)
+     next_token = tokenizer.decode(token_id[0])
+
+     return next_token.strip()
+
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("## 🔀 Blended System Prompts using Falcon-RW-1B")
+
+     with gr.Row():
+         prompt = gr.Textbox(label="User Prompt", value="Tell me a joke about computers.")
+
+     with gr.Row():
+         sys1 = gr.Textbox(label="System Prompt A", value="You are a polite assistant.")
+         sys2 = gr.Textbox(label="System Prompt B", value="You are a sarcastic assistant.")
+
+     with gr.Row():
+         wa = gr.Slider(-10, 10, value=1.0, step=0.1, label="Weight A")
+         wb = gr.Slider(-10, 10, value=1.0, step=0.1, label="Weight B")
+
+     with gr.Row():
+         temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
+
+     output = gr.Textbox(label="Next Token")
+
+     generate_btn = gr.Button("Generate Next Token")
+     generate_btn.click(
+         fn=blended_generate,
+         inputs=[prompt, sys1, sys2, wa, wb, temperature],
+         outputs=output,
+     )

+ demo.launch()
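
Note: the new blended_generate samples only a single next token per click, whereas the removed TinyLlama version looped up to max_new_tokens. A minimal sketch of how the same blended-logits sampling could be run autoregressively with the tokenizer and model loaded above; the helper blended_generate_n and its loop are illustrative and not part of this commit:

def blended_generate_n(prompt, sys1, sys2, wa, wb, temperature=1.0, max_new_tokens=30):
    # Illustrative extension (not in the commit): build both prompts once,
    # then repeat the blended next-token step, feeding each sampled token
    # back into both contexts, as the removed TinyLlama loop did.
    textA = f"<|system|>{sys1}\n<|user|>{prompt}<|assistant|>"
    textB = f"<|system|>{sys2}\n<|user|>{prompt}<|assistant|>"
    idsA = tokenizer(textA, return_tensors="pt").input_ids
    idsB = tokenizer(textB, return_tensors="pt").input_ids
    generated = []
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logitsA = model(input_ids=idsA).logits[:, -1, :]
            logitsB = model(input_ids=idsB).logits[:, -1, :]
        # Weighted sum of logits, temperature, softmax, then sample
        probs = F.softmax((wa * logitsA + wb * logitsB) / temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # shape [1, 1]
        if next_id.item() == tokenizer.eos_token_id:
            break
        generated.append(next_id.item())
        idsA = torch.cat([idsA, next_id], dim=1)
        idsB = torch.cat([idsB, next_id], dim=1)
    return tokenizer.decode(generated)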