Amossofer committed
Commit 0d6a629 · 1 Parent(s): 5d3ecec
Files changed (1)
  1. app.py +22 -59
app.py CHANGED
@@ -1,72 +1,35 @@
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import gradio as gr
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- model_name_a = "mistralai/Mistral-7B-v0.1"
- model_name_b = "mistralai/Mistral-7B-v0.1"  # you can replace this with a second different model or finetuned variant
-
- print("Loading tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(model_name_a)
-
- print("Loading models...")
- model_a = AutoModelForCausalLM.from_pretrained(model_name_a, device_map="auto", torch_dtype=torch.float16)
- model_b = AutoModelForCausalLM.from_pretrained(model_name_b, device_map="auto", torch_dtype=torch.float16)
-
- model_a.eval()
- model_b.eval()
-
- def blend_generate(system_prompt_a, system_prompt_b, user_prompt, wa, wb, max_length=50, temperature=0.7, top_k=50):
-     device = next(model_a.parameters()).device
-     generated_text = user_prompt
-
-     for _ in range(max_length):
-         prompt_a = system_prompt_a.strip() + "\n" + generated_text
-         prompt_b = system_prompt_b.strip() + "\n" + generated_text
-
-         input_ids_a = tokenizer(prompt_a, return_tensors="pt").input_ids.to(device)
-         input_ids_b = tokenizer(prompt_b, return_tensors="pt").input_ids.to(device)
-
-         with torch.no_grad():
-             output_a = model_a(input_ids_a)
-             output_b = model_b(input_ids_b)
-
-         logits_a = output_a.logits[:, -1, :]
-         logits_b = output_b.logits[:, -1, :]
-
-         blended_logits = wa * logits_a + wb * logits_b
-
-         # Apply top-k filtering
-         top_k_logits, top_k_indices = torch.topk(blended_logits, top_k)
-         filtered_logits = torch.full_like(blended_logits, float('-inf'))
-         filtered_logits.scatter_(1, top_k_indices, top_k_logits)
-
-         # Temperature scaling
-         scaled_logits = filtered_logits / temperature
-
-         probs = torch.softmax(scaled_logits, dim=-1)
-
-         next_token = torch.multinomial(probs, 1).item()
-
-         if next_token == tokenizer.eos_token_id:
-             break
-
-         next_token_str = tokenizer.decode([next_token])
-         generated_text += next_token_str
-
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ MODEL_NAME = "microsoft/Phi-4-mini-instruct"
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+ def generate_text(prompt, model, tokenizer, max_length=512, temperature=1, top_k=50, top_p=0.95):
+     inputs = tokenizer.encode(prompt, return_tensors="pt")
+
+     outputs = model.generate(
+         inputs,
+         max_length=max_length,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         do_sample=True
+     )
+
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
      return generated_text

- with gr.Blocks() as demo:
-     system_prompt_a = gr.Textbox(label="System Prompt A", value="You are a helpful assistant.")
-     system_prompt_b = gr.Textbox(label="System Prompt B", value="You are a witty assistant.")
-     user_prompt = gr.Textbox(label="User Prompt", value="Tell me a story about a dragon.")
-     weight_a = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model A")
-     weight_b = gr.Slider(minimum=0, maximum=1, value=0.5, label="Weight Model B")
-     output_text = gr.Textbox(label="Output", lines=10)
-
-     btn = gr.Button("Generate")
-     btn.click(blend_generate, inputs=[system_prompt_a, system_prompt_b, user_prompt, weight_a, weight_b], outputs=output_text)
-
- demo.launch()
+ def main():
+     # Define your prompt
+     prompt = "According to all known laws of aviation, there is no way a bee should be able to fly."
+
+     generated_text = generate_text(prompt, model, tokenizer)
+
+     print(generated_text)
+
+ if __name__ == "__main__":
+     main()
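
For anyone trying out the new script, below is a minimal usage sketch. It is an illustration only, not part of the commit: it assumes app.py is importable from the working directory, and the prompt and sampling values are made up for the example.

# Minimal usage sketch (illustrative; not part of the commit).
# Importing app runs its module-level code, so the Phi-4-mini-instruct
# weights are downloaded and loaded on first import.
from app import generate_text, model, tokenizer

# Lower temperature and tighter top_p than the defaults make the
# continuation less random; generation is still stochastic because
# generate_text passes do_sample=True.
text = generate_text(
    "Tell me a story about a dragon.",  # hypothetical prompt
    model,
    tokenizer,
    max_length=256,
    temperature=0.7,
    top_p=0.9,
)
print(text)

Because generate_text takes the model and tokenizer as arguments, a different checkpoint could be swapped in per call without touching the module-level globals.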