Scaryscar committed on
Commit
9664295
·
verified ·
1 Parent(s): 0ff169e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -27
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  import gradio as gr
5
  import os
6
 
7
- # Authenticate with Hugging Face
8
  login(token=os.environ.get("HF_TOKEN"))
9
 
10
  # Configuration
@@ -18,38 +18,40 @@ quant_config = BitsAndBytesConfig(
18
  bnb_4bit_quant_type="nf4"
19
  )
20
 
21
- # Load model with optimizations
22
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
23
- model = AutoModelForCausalLM.from_pretrained(
24
- MODEL_NAME,
25
- quantization_config=quant_config,
26
- device_map="auto",
27
- torch_dtype=torch.float16,
28
- cache_dir=CACHE_DIR
29
- )
 
 
 
30
 
31
def solve_math(question):
    """Return the model's step-by-step solution text for *question*.

    Builds a fixed prompt, runs sampled generation on the globally loaded
    model, and decodes the full output (prompt included) as the answer.
    """
    prompt = f"""Solve this math problem step by step:

Question: {question}
Answer:"""
    # Move the encoded prompt to wherever the model lives (device_map="auto").
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated = model.generate(
        **encoded,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
 
46
  # Gradio Interface
47
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
48
  gr.Markdown("""<h1><center>🧮 Gemma-2B Math Solver</center></h1>""")
49
  with gr.Row():
50
  question = gr.Textbox(
51
- label="Enter your math problem",
52
- placeholder="e.g., What is the derivative of x^2?",
53
  lines=3
54
  )
55
  with gr.Row():
@@ -70,6 +72,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
70
  if __name__ == "__main__":
71
  demo.launch(
72
  server_name="0.0.0.0",
73
- server_port=7860,
74
- share=False
75
  )
 
4
  import gradio as gr
5
  import os
6
 
7
# Authenticate against the Hugging Face Hub (gated model access).
# NOTE(review): if HF_TOKEN is unset this passes token=None — presumably
# login() then fails or falls back to cached credentials; confirm.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)
9
 
10
  # Configuration
 
18
  bnb_4bit_quant_type="nf4"
19
  )
20
 
21
# Load tokenizer and 4-bit quantized model once at startup.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",          # let accelerate place layers on available devices
        torch_dtype=torch.float16,
        cache_dir=CACHE_DIR,
    )
except Exception as e:
    # Chain the original exception so the real cause (auth, OOM, network)
    # survives in the traceback instead of being flattened into a string.
    # NOTE(review): gr.Error raised at import time never reaches the UI —
    # presumably intended only to abort startup with a readable message; confirm.
    raise gr.Error(f"Model loading failed: {str(e)}") from e
33
 
34
def solve_math(question):
    """Generate a step-by-step solution for *question*.

    Returns the decoded model output (prompt included). Any failure in
    tokenization or generation is caught and surfaced to the Gradio UI
    as an "Error: ..." string rather than crashing the handler.
    """
    try:
        prompt = f"Solve step by step: {question}\nAnswer:"
        # Encode on the model's device; sample with a low temperature for
        # mostly-deterministic math answers.
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            temperature=0.3,
            do_sample=True,
        )
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as err:
        return f"Error: {str(err)}"
    return decoded
 
47
 
48
  # Gradio Interface
49
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
50
  gr.Markdown("""<h1><center>🧮 Gemma-2B Math Solver</center></h1>""")
51
  with gr.Row():
52
  question = gr.Textbox(
53
+ label="Math Problem",
54
+ placeholder="What is 2^10 + 5*3?",
55
  lines=3
56
  )
57
  with gr.Row():
 
72
if __name__ == "__main__":
    # Serve on all interfaces at the standard Gradio port.
    demo.launch(server_name="0.0.0.0", server_port=7860)