Amossofer committed on
Commit
82ff832
1 Parent(s): f1b73e6
Files changed (1)
  1. app.py +32 -26
app.py CHANGED
@@ -1,31 +1,29 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 
-# Set device: GPU if available, else CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load two small models and their tokenizer (you can replace these with your models)
-MODEL_NAME = "arnir0/Tiny-LLM"
-model_name_a = MODEL_NAME
-model_name_b = MODEL_NAME
+model_name_a = "meta-llama/Llama-2-7b-chat-hf"
+model_name_b = "meta-llama/Llama-2-7b-chat-hf"  # you can replace this with a second different model or finetuned variant
 
+print("Loading tokenizer...")
 tokenizer = AutoTokenizer.from_pretrained(model_name_a)
 
-model_a = AutoModelForCausalLM.from_pretrained(model_name_a).to(device)
-model_b = AutoModelForCausalLM.from_pretrained(model_name_b).to(device)
+print("Loading models...")
+model_a = AutoModelForCausalLM.from_pretrained(model_name_a, device_map="auto", torch_dtype=torch.float16)
+model_b = AutoModelForCausalLM.from_pretrained(model_name_b, device_map="auto", torch_dtype=torch.float16)
 
 model_a.eval()
 model_b.eval()
 
-def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_length=50):
+def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_length=50, temperature=0.7, top_k=50):
+    device = next(model_a.parameters()).device
     generated_text = user_prompt
-    device = next(model_a.parameters()).device  # infer device from model
 
     for _ in range(max_length):
-        # Prepare prompts for each model: system prompt + generated text so far
-        prompt_a = system_prompt_a + generated_text
-        prompt_b = system_prompt_b + generated_text
+        prompt_a = system_prompt_a.strip() + "\n" + generated_text
+        prompt_b = system_prompt_b.strip() + "\n" + generated_text
 
         input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
         input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)
@@ -39,26 +37,34 @@ def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_le
 
         blended_logits = wa * logits_a + wb * logits_b
 
-        probs = torch.softmax(blended_logits, dim=-1)
-        token = torch.multinomial(probs, 1)
-        next_token_id = token.item()
+        # Apply top-k filtering
+        top_k_logits, top_k_indices = torch.topk(blended_logits, top_k)
+        filtered_logits = torch.full_like(blended_logits, float('-inf'))
+        filtered_logits.scatter_(1, top_k_indices, top_k_logits)
 
-        # Stop if end-of-sequence token generated (adjust based on your tokenizer)
-        if next_token_id == tokenizer.eos_token_id:
+        # Temperature scaling
+        scaled_logits = filtered_logits / temperature
+
+        probs = torch.softmax(scaled_logits, dim=-1)
+
+        next_token = torch.multinomial(probs, 1).item()
+
+        if next_token == tokenizer.eos_token_id:
             break
 
-        next_token = tokenizer.decode([next_token_id])
-        generated_text += next_token
+        next_token_str = tokenizer.decode([next_token])
+        generated_text += next_token_str
 
     return generated_text
 
+
 with gr.Blocks() as demo:
-    system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a funny assistant. ")
-    system_prompt_b = gr.Textbox(label="System Prompt B", value="You are a angry assistant. ")
-    user_prompt = gr.Textbox(label="User Prompt",value="tell me a story")
-    weight_a = gr.Slider(-2, 2, value=1, label="Weight Model A")
-    weight_b = gr.Slider(-2, 2, value=1, label="Weight Model B")
-    output_text = gr.Textbox(label="Output")
+    system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a helpful assistant.")
+    system_prompt_b = gr.Textbox(label="System Prompt B", value="You are a witty assistant.")
+    user_prompt = gr.Textbox(label="User Prompt", value="Tell me a story about a dragon.")
+    weight_a = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model A")
+    weight_b = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model B")
+    output_text = gr.Textbox(label="Output", lines=10)
 
     btn = gr.Button("Generate")
     btn.click(blend_generate, inputs=[system_prompt_a, system_prompt_b, user_prompt, weight_a, weight_b], outputs=output_text)
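The heart of the commit is the new sampling step: the two models' last-token logits are mixed with weights wa and wb, cut down to the top_k highest-scoring tokens, sharpened by temperature, then sampled with torch.multinomial. A self-contained sketch of that step, using random toy logits in place of real model outputs (the sample_blended name is illustrative, not from app.py):

import torch

def sample_blended(logits_a, logits_b, wa=0.5, wb=0.5, top_k=50, temperature=0.7):
    # Weighted mix of the two models' next-token logits, shape [1, vocab_size]
    blended = wa * logits_a + wb * logits_b
    # Top-k filtering: everything outside the k best tokens becomes -inf
    top_vals, top_idx = torch.topk(blended, top_k)
    filtered = torch.full_like(blended, float("-inf"))
    filtered.scatter_(1, top_idx, top_vals)
    # Temperature scaling before softmax: <1 sharpens, >1 flattens the distribution
    probs = torch.softmax(filtered / temperature, dim=-1)
    # Sample one token id from the filtered distribution
    return torch.multinomial(probs, 1).item()

# Quick check with a toy 100-token vocabulary
torch.manual_seed(0)
print(sample_blended(torch.randn(1, 100), torch.randn(1, 100), top_k=10))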
 
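Neither hunk shows old lines 32-38 (new lines 30-36), where logits_a and logits_b are computed; those lines are untouched by this commit, so the diff omits them. With standard transformers causal-LM outputs they would plausibly look like the sketch below, which is an assumption, not the file's actual hidden code:

import torch

def last_token_logits(model, input_ids):
    # Hypothetical helper mirroring what app.py's hidden lines presumably do:
    # run the causal LM and keep only the logits for the final position.
    with torch.no_grad():
        return model(input_ids).logits[:, -1, :]

# Inside the generation loop this would correspond to:
#     logits_a = last_token_logits(model_a, input_ids_a)
#     logits_b = last_token_logits(model_b, input_ids_b)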
 
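One consequence of the new signature: btn.click still wires only the five original inputs, so temperature and top_k always take their defaults (0.7 and 50) in the UI. They can still be exercised by calling blend_generate directly, for example:

# Direct call, bypassing the Gradio UI; argument values are examples only
text = blend_generate(
    "You are a helpful assistant.",
    "You are a witty assistant.",
    "Tell me a story about a dragon.",
    wa=0.7,
    wb=0.3,
    max_length=30,
    temperature=0.9,
    top_k=20,
)
print(text)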