arshadrana commited on
Commit
d7968e8
·
verified ·
1 Parent(s): d20da4a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
+
5
+ # Load the model and tokenizer
6
+ model_name = "Qwen/Qwen2-Math-1.5B-Instruct"
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(
10
+ model_name,
11
+ torch_dtype="auto",
12
+ device_map="auto"
13
+ ).to(device)
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+
17
+ # Define a function for Gradio to handle user input
18
+ def solve_math(prompt):
19
+ messages = [
20
+ {"role": "system", "content": "You are a helpful assistant."},
21
+ {"role": "user", "content": prompt}
22
+ ]
23
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
24
+ model_inputs = tokenizer([text], return_tensors="pt").to(device)
25
+
26
+ generation_config = GenerationConfig(
27
+ do_sample=False, # For greedy decoding
28
+ max_new_tokens=512
29
+ )
30
+
31
+ generated_ids = model.generate(
32
+ **model_inputs,
33
+ generation_config=generation_config
34
+ )
35
+
36
+ # Remove the input tokens from the output
37
+ generated_ids = [
38
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
39
+ ]
40
+
41
+ # Decode the generated output and return the result
42
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
43
+ return response
44
+
45
+ # Create the Gradio interface
46
+ iface = gr.Interface(
47
+ fn=solve_math, # Function to call
48
+ inputs="text", # Text input for the user prompt
49
+ outputs="text", # Text output for the model's response
50
+ title="Math Solver", # App title
51
+ description="Provide a math problem and the model will solve it."
52
+ )
53
+
54
+ # Launch the app
55
+ if __name__ == "__main__":
56
+ iface.launch()