Amossofer committed on
Commit
51d9d55
1 Parent(s): 38b5252
Files changed (1)
  1. app.py +30 -21
app.py CHANGED
@@ -6,8 +6,9 @@ import gradio as gr
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Load two small models and their tokenizer (you can replace these with your models)
-model_name_a = "distilgpt2"
-model_name_b = "sshleifer/tiny-gpt2"  # very small GPT2 variant for demo
+MODEL_NAME = "arnir0/Tiny-LLM"
+model_name_a = MODEL_NAME
+model_name_b = MODEL_NAME
 
 tokenizer = AutoTokenizer.from_pretrained(model_name_a)
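Note on this hunk: both model names now point at the same arnir0/Tiny-LLM checkpoint, so the two-model blend temporarily degenerates to one model sampled under two system prompts. Element-wise logit blending also silently assumes both checkpoints share the tokenizer's vocabulary. A sketch of restoring two genuinely different but vocab-compatible models (not part of this commit; distilgpt2 and sshleifer/tiny-gpt2 both use the GPT-2 vocabulary):

# Sketch (assumption, not in the commit): two distinct checkpoints whose
# logits can be summed element-wise because they share one vocabulary.
model_name_a = "distilgpt2"
model_name_b = "sshleifer/tiny-gpt2"
tok_a = AutoTokenizer.from_pretrained(model_name_a)
tok_b = AutoTokenizer.from_pretrained(model_name_b)
assert tok_a.vocab_size == tok_b.vocab_size  # precondition for logit blending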
 
@@ -17,31 +18,39 @@ model_b = AutoModelForCausalLM.from_pretrained(model_name_b).to(device)
 model_a.eval()
 model_b.eval()
 
-def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb):
-    # Combine system prompt A + user prompt for model A
-    prompt_a = system_prompt_a + user_prompt
-    # Combine system prompt B + user prompt for model B
-    prompt_b = system_prompt_b + user_prompt
-
-    input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
-    input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)
-
-    with torch.no_grad():
-        output_a = model_a(input_ids_a)
-        output_b = model_b(input_ids_b)
-
-    logits_a = output_a.logits[:, -1, :]
-    logits_b = output_b.logits[:, -1, :]
-
-    blended_logits = wa * logits_a + wb * logits_b
-
-    probs = torch.softmax(blended_logits, dim=-1)
-    token = torch.multinomial(probs, 1)
-    next_token_id = token.item()
-    next_token = tokenizer.decode([next_token_id])
-
-    # For simplicity, just return user prompt + next token (you can customize)
-    return user_prompt + next_token
+def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_length=50):
+    generated_text = user_prompt
+    device = next(model_a.parameters()).device  # infer device from model
+
+    for _ in range(max_length):
+        # Prepare prompts for each model: system prompt + generated text so far
+        prompt_a = system_prompt_a + generated_text
+        prompt_b = system_prompt_b + generated_text
+
+        input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
+        input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)
+
+        with torch.no_grad():
+            output_a = model_a(input_ids_a)
+            output_b = model_b(input_ids_b)
+
+        logits_a = output_a.logits[:, -1, :]
+        logits_b = output_b.logits[:, -1, :]
+
+        blended_logits = wa * logits_a + wb * logits_b
+
+        probs = torch.softmax(blended_logits, dim=-1)
+        token = torch.multinomial(probs, 1)
+        next_token_id = token.item()
+
+        # Stop if end-of-sequence token generated (adjust based on your tokenizer)
+        if next_token_id == tokenizer.eos_token_id:
+            break
+
+        next_token = tokenizer.decode([next_token_id])
+        generated_text += next_token
+
+    return generated_text
 
 with gr.Blocks() as demo:
     system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a helpful assistant. ")
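To try the patched blend_generate outside the Gradio UI, a minimal sketch; it assumes the setup from the diff has already run, and the second persona, weights, and prompt below are made-up illustration values:

# Sketch: call blend_generate directly with hypothetical inputs.
text = blend_generate(
    system_prompt_a="You are a helpful assistant. ",
    system_prompt_b="You are a very terse assistant. ",  # hypothetical persona
    user_prompt="The capital of France is",
    wa=0.7,          # weight on model A's next-token logits
    wb=0.3,          # weight on model B's next-token logits
    max_length=20,   # cap on sampled tokens
)
print(text)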
 
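Design note on the blending step: summing logits with weights does not yield a mixture of the two models' distributions, and weights that do not sum to 1 act like a softmax temperature change. If a true mixture is wanted, a probability-space variant (a sketch, not what this commit does) would replace the sampling lines with:

# Sketch: blend in probability space; a renormalized convex combination
# of the two softmax outputs is a valid distribution for any wa, wb >= 0.
probs_a = torch.softmax(logits_a, dim=-1)
probs_b = torch.softmax(logits_b, dim=-1)
probs = wa * probs_a + wb * probs_b
probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalize the mixture
token = torch.multinomial(probs, 1)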
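One cost worth noting: every loop iteration re-tokenizes and re-runs the full, growing prompt through both models, so generation time is quadratic in sequence length. A cached variant along these lines (a sketch reusing the diff's names, not part of the commit) feeds each model only the newly sampled token after the first pass; it appends token ids rather than re-encoding the concatenated string, which can tokenize slightly differently:

# Sketch: incremental decoding with per-model KV caches.
ids_a = tokenizer(system_prompt_a + user_prompt, return_tensors="pt").input_ids.to(device)
ids_b = tokenizer(system_prompt_b + user_prompt, return_tensors="pt").input_ids.to(device)
generated_ids = []
with torch.no_grad():
    out_a = model_a(ids_a, use_cache=True)
    out_b = model_b(ids_b, use_cache=True)
    for _ in range(max_length):
        blended = wa * out_a.logits[:, -1, :] + wb * out_b.logits[:, -1, :]
        token = torch.multinomial(torch.softmax(blended, dim=-1), 1)  # shape (1, 1)
        if token.item() == tokenizer.eos_token_id:
            break
        generated_ids.append(token.item())
        out_a = model_a(token, past_key_values=out_a.past_key_values, use_cache=True)
        out_b = model_b(token, past_key_values=out_b.past_key_values, use_cache=True)
generated_text = user_prompt + tokenizer.decode(generated_ids)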