Scaryscar commited on
Commit
c99d5db
·
verified ·
1 Parent(s): 10c007d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -11
app.py CHANGED
@@ -1,39 +1,75 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
  from huggingface_hub import login
3
  import torch
4
  import gradio as gr
5
  import os
6
 
7
- # Authenticate using HF_TOKEN from Space secrets
8
  login(token=os.environ.get("HF_TOKEN"))
9
 
10
  # Configuration
11
  MODEL_NAME = "google/gemma-2b-it"
12
  CACHE_DIR = "/tmp"
13
 
14
- # Load model with authentication
 
 
 
 
 
 
 
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  MODEL_NAME,
 
18
  device_map="auto",
19
  torch_dtype=torch.float16,
20
  cache_dir=CACHE_DIR
21
  )
22
 
23
  def solve_math(question):
24
- prompt = f"Question: {question}\nAnswer:"
 
 
 
 
25
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
26
- outputs = model.generate(**inputs, max_new_tokens=200)
 
 
 
 
 
27
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
28
 
29
  # Gradio Interface
30
- with gr.Blocks() as demo:
31
- gr.Markdown("## Gemma-2B Math Solver")
 
 
 
 
 
 
32
  with gr.Row():
33
- question = gr.Textbox(label="Math Problem", placeholder="Enter your question here...")
34
  with gr.Row():
35
- answer = gr.Textbox(label="Solution", interactive=False)
36
- question.submit(fn=solve_math, inputs=question, outputs=answer)
 
 
 
 
 
 
 
 
 
37
 
38
  if __name__ == "__main__":
39
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
2
  from huggingface_hub import login
3
  import torch
4
  import gradio as gr
5
  import os
6
 
7
# Authenticate with Hugging Face (Gemma is a gated model, so a valid
# HF_TOKEN must be provided via the Space secrets).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # login(token=None) falls back to an interactive prompt, which hangs in a
    # headless Space; skip login and let from_pretrained raise a clear
    # gated-repo error instead.
    print("WARNING: HF_TOKEN is not set; downloading gated models will fail.")

# Configuration
MODEL_NAME = "google/gemma-2b-it"  # instruction-tuned Gemma 2B checkpoint
CACHE_DIR = "/tmp"                 # model download/cache location
 
14
# Quantize weights to 4-bit NF4 so the model fits in limited Space memory;
# matmuls are still computed in float16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
20
+
21
# Load tokenizer and the 4-bit quantized model; both are cached under
# CACHE_DIR so restarts reuse the downloaded checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,  # 4-bit NF4 config defined above
    device_map="auto",  # let transformers place weights on available devices
    torch_dtype=torch.float16,
    cache_dir=CACHE_DIR
)
30
 
31
def solve_math(question):
    """Generate a step-by-step solution for a math question.

    Uses the module-level `tokenizer` and `model`. Sampling is enabled
    (temperature=0.7), so answers are not deterministic.

    Args:
        question: The math problem as plain text.

    Returns:
        The model's generated answer as a string (prompt not included).
    """
    prompt = f"""Solve this math problem step by step:

Question: {question}
Answer:"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True
    )
    # generate() returns prompt + completion; slice off the prompt tokens so
    # the UI shows only the solution instead of echoing the whole prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
45
 
46
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""<h1><center>🧮 Gemma-2B Math Solver</center></h1>""")
    with gr.Row():
        question = gr.Textbox(
            label="Enter your math problem",
            placeholder="e.g., What is the derivative of x^2?",
            lines=3
        )
    with gr.Row():
        submit_btn = gr.Button("Solve", variant="primary")
    with gr.Row():
        answer = gr.Textbox(
            label="Solution",
            lines=5,
            interactive=False
        )

    # Trigger solving from the button AND from pressing Enter in the textbox
    # (the Enter-to-submit binding existed previously and was dropped).
    submit_btn.click(
        fn=solve_math,
        inputs=question,
        outputs=answer
    )
    question.submit(
        fn=solve_math,
        inputs=question,
        outputs=answer
    )
69
 
70
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the standard HF Spaces port);
    # do not create a public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)