Amossofer commited on
Commit
142eb42
·
1 Parent(s): 2ef1e0a
Files changed (1) hide show
  1. app.py +50 -30
app.py CHANGED
@@ -1,46 +1,66 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
  import gradio as gr
3
  import torch
 
 
4
 
5
- # Load tiny model from Hugging Face
6
  model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
7
  tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
11
  )
 
12
 
13
- # Use text-generation pipeline (without `device=0`)
14
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
15
-
16
- # Function to blend two prompts with weights (wa and wb)
17
- def blend_and_generate(prompt_a, prompt_b, wa, wb):
18
- # Normalize weights even if negative
19
- total = abs(wa) + abs(wb)
20
- if total == 0:
21
- return "Error: Both weights cannot be zero."
22
- norm_wa = wa / total
23
- norm_wb = wb / total
24
-
25
- # Create blended prompt
26
- blended_prompt = f"{norm_wa:.2f} * ({prompt_a}) + {norm_wb:.2f} * ({prompt_b})"
27
- generated = generator(blended_prompt, max_new_tokens=100, do_sample=True, temperature=0.7)
28
- return generated[0]["generated_text"]
29
-
30
- # Gradio UI
31
- demo = gr.Interface(
32
- fn=blend_and_generate,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  inputs=[
34
- gr.Textbox(label="Prompt A"),
35
- gr.Textbox(label="Prompt B"),
36
- gr.Slider(minimum=-5, maximum=5, step=0.1, label="Weight A (wa)"),
37
- gr.Slider(minimum=-5, maximum=5, step=0.1, label="Weight B (wb)"),
 
 
 
 
38
  ],
39
- outputs=gr.Textbox(label="Generated Output"),
40
- title="Tiny Prompt Blender (TinyLlama-1.1B)",
41
- description="Enter two prompts and blend them using wa and wb (can be negative).",
42
  )
43
 
44
- # Launch app
45
  if __name__ == "__main__":
46
  demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
 
6
  model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
7
  tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
11
  )
12
+ model.eval()
13
 
14
+ def generate_stream(sysA, sysB, wA, wB, user_message, max_new_tokens=100, temperature=1.0, top_p=0.9):
15
+ promptA = f"<|system|>{sysA}\n<|user|>{user_message}\n<|assistant|>"
16
+ promptB = f"<|system|>{sysB}\n<|user|>{user_message}\n<|assistant|>"
17
+
18
+ idsA = tokenizer(promptA, return_tensors="pt").input_ids.to(model.device)
19
+ idsB = tokenizer(promptB, return_tensors="pt").input_ids.to(model.device)
20
+
21
+ outA, outB = idsA.clone(), idsB.clone()
22
+ response = ""
23
+ yield response # start stream
24
+
25
+ for _ in range(max_new_tokens):
26
+ with torch.no_grad():
27
+ logitsA = model(input_ids=outA).logits[:, -1, :]
28
+ logitsB = model(input_ids=outB).logits[:, -1, :]
29
+
30
+ blended = wA * logitsA + wB * logitsB
31
+ blended = blended / (temperature if temperature > 0 else 1.0)
32
+
33
+ probs = F.softmax(blended, dim=-1)
34
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True)
35
+ cum = torch.cumsum(sorted_probs, dim=-1)
36
+ sorted_probs[cum > top_p] = 0
37
+ sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
38
+
39
+ token = sorted_idx[:, torch.multinomial(sorted_probs, 1)].squeeze()
40
+ outA = torch.cat([outA, token.unsqueeze(0).unsqueeze(0)], dim=1)
41
+ outB = torch.cat([outB, token.unsqueeze(0).unsqueeze(0)], dim=1)
42
+
43
+ token_str = tokenizer.decode(token)
44
+ response += token_str
45
+ yield response
46
+
47
+ if token.item() == tokenizer.eos_token_id:
48
+ break
49
+
50
+ demo = gr.ChatInterface(
51
+ fn=generate_stream,
52
  inputs=[
53
+ gr.Textbox(label="System Prompt A", value="You are assistant A."),
54
+ gr.Textbox(label="System Prompt B", value="You are assistant B."),
55
+ gr.Slider(label="Weight wA", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
56
+ gr.Slider(label="Weight wB", minimum=-5.0, maximum=5.0, step=0.1, value=1.0),
57
+ gr.Textbox(label="User Message"),
58
+ gr.Slider(label="Max New Tokens", minimum=1, maximum=200, step=1, value=100),
59
+ gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
60
+ gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9),
61
  ],
62
+ title="Streaming Blended TinyLlama Chat"
 
 
63
  )
64
 
 
65
  if __name__ == "__main__":
66
  demo.launch()