broadfield-dev committed
Commit c57dbeb · verified · Parent: 9ac5ded

Update app.py

Files changed (1)
  1. app.py +28 -20
app.py CHANGED
@@ -11,8 +11,8 @@ SYSTEM_PROMPT = "You are a helpful and friendly AI assistant."
 # Log in using the secret token
 login(token=getenv("HF_TOKEN"))
 
-# Load Gemma 2B with optimizations for CPU
-model_name = "google/gemma-2b-270m"
+# Load the specified model with CPU optimizations
+model_name = "google/gemma-3-270m"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
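
Note on this hunk: the functional change is the checkpoint id. The old `google/gemma-2b-270m` mixes two model names (Gemma 2B and the 270M checkpoint) and does not resolve to a Hub repo, while `google/gemma-3-270m` is the actual Gemma 3 270M model. A minimal standalone version of the load; the dtype argument is an assumption, since it sits outside this hunk's context lines:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-3-270m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # assumed: full precision is the safe default on CPU
    device_map="cpu",           # from the diff: explicit CPU placement
)
```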
@@ -21,32 +21,38 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="cpu" # Explicitly map to CPU
 )
 
-# Simplify Gradio interface
+# Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Gemma 2B Chatbot (CPU-Optimized)")
-
+    gr.Markdown("# Gemma 3 270M Chatbot (CPU-Optimized)")
+
     with gr.Row():
         with gr.Column(scale=4):
-            chatbot = gr.Chatbot(label="Chat")
+            chatbot = gr.Chatbot(label="Gemma 3 Chat")
             text_input = gr.Textbox(label="Your message")
             submit_button = gr.Button("Send")
-
+
         with gr.Column(scale=1):
-            gr.Markdown("## Settings")
-            max_length_slider = gr.Slider(minimum=20, maximum=512, value=100, label="Max New Tokens")
-            temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
+            gr.Markdown("## User Controls")
+            max_length_slider = gr.Slider(minimum=20, maximum=512, value=100, label="Max New Tokens") # Reduced max for CPU
+            temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
+            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p")
+            top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
 
     def build_gemma_prompt(chat_history, new_message):
-        # Simplified prompt in Gemma's expected format
-        prompt = f"{SYSTEM_PROMPT}\n\n"
-        for user_msg, model_msg in chat_history:
+        # Simplified prompt construction in Gemma format
+        prompt = ""
+        for i, (user_msg, model_msg) in enumerate(chat_history):
+            if i == 0:
+                user_msg = f"{SYSTEM_PROMPT}\n\n{user_msg}"
             prompt += f"<start_of_turn>user\n{user_msg}<end_of_turn>\n"
             if model_msg:
                 prompt += f"<start_of_turn>model\n{model_msg}<end_of_turn>\n"
+        if not chat_history:
+            new_message = f"{SYSTEM_PROMPT}\n\n{new_message}"
         prompt += f"<start_of_turn>user\n{new_message}<end_of_turn>\n<start_of_turn>model\n"
         return prompt
 
-    def respond(message, chat_history, max_length, temperature):
+    def respond(message, chat_history, max_length, temperature, top_p, top_k):
         # Build prompt
         full_prompt = build_gemma_prompt(chat_history, message)
 
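Note on `build_gemma_prompt`: Gemma's turn format has no dedicated system role, so the rewrite folds `SYSTEM_PROMPT` into the first user turn instead of emitting it as bare text ahead of the first `<start_of_turn>`, which is the off-format input the old version produced. A sketch of an alternative that reaches the same framing with less hand-rolled string work, assuming the tokenizer ships a chat template (Gemma tokenizers on the Hub generally do):

```python
# Alternative sketch, not what this commit does: let the tokenizer's
# chat template emit the <start_of_turn> framing.
messages = [{"role": "user", "content": f"{SYSTEM_PROMPT}\n\nHello!"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # appends "<start_of_turn>model\n" for the reply
)
```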
@@ -56,14 +62,14 @@ with gr.Blocks() as demo:
         # Update UI history
         chat_history.append((message, ""))
 
-        # Initialize streamer with proper token skipping
+        # Initialize streamer with proper token handling
         streamer = TextIteratorStreamer(
             tokenizer,
             skip_prompt=True,
             skip_special_tokens=True,
-            clean_up_tokenization_spaces=True # Avoid gibberish from token artifacts
+            clean_up_tokenization_spaces=True # Prevent token artifacts
         )
-
+
         # Generation parameters
         generation_kwargs = {
             "input_ids": inputs["input_ids"],
@@ -71,14 +77,16 @@ with gr.Blocks() as demo:
             "streamer": streamer,
             "max_new_tokens": int(max_length),
             "temperature": float(temperature),
+            "top_p": float(top_p),
+            "top_k": int(top_k),
             "do_sample": True
         }
 
-        # Run generation in a separate thread with no_grad
+        # Run generation with no_grad for memory efficiency
         with torch.no_grad():
             thread = Thread(target=model.generate, kwargs=generation_kwargs)
             thread.start()
-
+
         # Stream response
         accumulated_text = ""
         for new_text in streamer:
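
One caveat worth flagging: gradient mode in PyTorch is thread-local, so a `with torch.no_grad():` entered on the UI thread does not apply inside the worker thread where `model.generate()` actually runs. It is harmless here because `generate()` disables gradients internally, but to make the intent explicit the context has to move into the thread target. A sketch with a hypothetical wrapper, not part of this commit:

```python
def generate_no_grad(**kwargs):
    # Enter no_grad on the thread that executes generate(), where it matters.
    with torch.no_grad():
        model.generate(**kwargs)

thread = Thread(target=generate_no_grad, kwargs=generation_kwargs)
thread.start()
```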
@@ -88,7 +96,7 @@ with gr.Blocks() as demo:
 
     submit_button.click(
         respond,
-        [text_input, chatbot, max_length_slider, temperature_slider],
+        [text_input, chatbot, max_length_slider, temperature_slider, top_p_slider, top_k_slider],
        [text_input, chatbot]
     )
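
The six inputs wired here line up positionally with `respond(message, chat_history, max_length, temperature, top_p, top_k)`. The end of `respond` is elided by the hunk context, but for the streamed text to reach the UI it must be a generator whose yields match the two outputs; a sketch of that assumed contract, an illustration rather than the committed code:

```python
def respond(message, chat_history, max_length, temperature, top_p, top_k):
    ...                                # prompt building and generation as above
    accumulated_text = ""
    for new_text in streamer:
        accumulated_text += new_text
        chat_history[-1] = (message, accumulated_text)
        yield "", chat_history         # clear the textbox, refresh the Chatbot
```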