Scaryscar committed on
Commit
9db5cfb
·
verified ·
1 Parent(s): 9664295

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -18
app.py CHANGED
@@ -4,21 +4,22 @@ import torch
4
  import gradio as gr
5
  import os
6
 
7
- # Authenticate
8
- login(token=os.environ.get("HF_TOKEN"))
9
-
10
  # Configuration
11
  MODEL_NAME = "google/gemma-2b-it"
12
  CACHE_DIR = "/tmp"
 
 
 
 
13
 
14
- # 4-bit quantization config
15
  quant_config = BitsAndBytesConfig(
16
  load_in_4bit=True,
17
  bnb_4bit_compute_dtype=torch.float16,
18
  bnb_4bit_quant_type="nf4"
19
  )
20
 
21
- # Load model
22
  try:
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
24
  model = AutoModelForCausalLM.from_pretrained(
@@ -29,44 +30,67 @@ try:
29
  cache_dir=CACHE_DIR
30
  )
31
  except Exception as e:
32
- raise gr.Error(f"Model loading failed: {str(e)}")
33
 
34
  def solve_math(question):
 
35
  try:
36
- prompt = f"Solve step by step: {question}\nAnswer:"
37
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
38
  outputs = model.generate(
39
  **inputs,
40
- max_new_tokens=256,
41
- temperature=0.3,
42
- do_sample=True
 
43
  )
44
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
45
  except Exception as e:
46
- return f"Error: {str(e)}"
 
 
 
47
 
48
  # Gradio Interface
49
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
50
- gr.Markdown("""<h1><center>🧮 Gemma-2B Math Solver</center></h1>""")
 
51
  with gr.Row():
52
  question = gr.Textbox(
53
- label="Math Problem",
54
- placeholder="What is 2^10 + 5*3?",
55
  lines=3
56
  )
 
57
  with gr.Row():
58
  submit_btn = gr.Button("Solve", variant="primary")
 
59
  with gr.Row():
60
  answer = gr.Textbox(
61
- label="Solution",
62
- lines=5,
63
  interactive=False
64
  )
65
 
 
 
 
 
 
 
 
 
 
 
66
  submit_btn.click(
67
  fn=solve_math,
68
  inputs=question,
69
- outputs=answer
 
70
  )
71
 
72
  if __name__ == "__main__":
 
4
  import gradio as gr
5
  import os
6
 
 
 
 
# Configuration
MODEL_NAME = "google/gemma-2b-it"  # Gemma 2B instruction-tuned checkpoint (gated — needs a token)
CACHE_DIR = "/tmp"                 # Spaces only guarantee /tmp is writable
MAX_TOKENS = 200                   # Cap on generated tokens — kept low for faster responses

# Authenticate with the Hugging Face Hub (HF_TOKEN must be set in Space secrets).
# Guard against a missing secret: login(token=None) fails with a confusing
# downstream error, while an explicit check makes the misconfiguration obvious
# at startup.
_hf_token = os.environ.get("HF_TOKEN")
if not _hf_token:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; add it to the Space secrets "
        "so the gated Gemma model can be downloaded."
    )
login(token=_hf_token)
14
 
15
# Quantize weights to 4-bit NF4 so the 2B model fits in limited GPU memory;
# compute still runs in fp16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
21
 
22
+ # Load model with error handling
23
  try:
24
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
25
  model = AutoModelForCausalLM.from_pretrained(
 
30
  cache_dir=CACHE_DIR
31
  )
32
  except Exception as e:
33
+ raise gr.Error(f"⚠️ Model loading failed. Please check your token and try again.\nError: {str(e)}")
34
 
35
def solve_math(question):
    """Generate a step-by-step solution for a math problem.

    Args:
        question: Natural-language math problem from the UI textbox.

    Returns:
        The model's answer text (everything after the final "Answer:" marker),
        or a human-readable error string — the UI displays whatever is
        returned, so this function never raises.
    """
    # Guard against empty/whitespace input so we don't waste a generate() call
    # on a blank prompt.
    if not question or not question.strip():
        return "Please enter a math problem."
    try:
        prompt = f"Solve this step by step:\n\nQuestion: {question}\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            temperature=0.3,  # lower temperature = more deterministic answers
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
        )

        # Decode, then keep only the text after the prompt's "Answer:" marker
        # (the model echoes the prompt in its output).
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer.split("Answer:")[-1].strip()

    except Exception as e:
        # Surface generation failures to the UI instead of crashing the app.
        return f"Error generating answer: {str(e)}"
54
+
55
+ # Preload model for faster first response
56
+ solve_math("2+2=") # Warm-up call
57
 
58
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page title. NOTE(review): the previous revision's 🧮 emoji appears to have
    # been dropped from this heading — confirm that was intentional.
    gr.Markdown("""<h1><center> Gemma-2B Math Solver</center></h1>""")

    # Input row: free-form problem statement.
    with gr.Row():
        problem_box = gr.Textbox(
            lines=3,
            label="Enter your math problem",
            placeholder="What is the integral of x^2 from 0 to 3?",
        )

    # Action row: single solve button.
    with gr.Row():
        solve_button = gr.Button("Solve", variant="primary")

    # Output row: read-only solution display.
    with gr.Row():
        solution_box = gr.Textbox(
            lines=6,
            label="Step-by-step solution",
            interactive=False,
        )

    # Canned problems for one-click testing.
    gr.Examples(
        examples=[
            ["What is 2^10 + 5*3?"],
            ["Solve for x: 3x + 5 = 20"],
            ["Calculate the area of a circle with radius 4"],
        ],
        inputs=problem_box,
    )

    # Wire the button to the solver; also exposed as the "solve" API endpoint.
    solve_button.click(
        fn=solve_math,
        inputs=problem_box,
        outputs=solution_box,
        api_name="solve",
    )
95
 
96
  if __name__ == "__main__":