Jn-Huang committed on
Commit 4cc1531 · 1 Parent(s): fc3b3a2

Fix Gradio ChatInterface: remove lambda wrapper, add lazy loading, make public

Files changed (1)
app.py +17 -7
app.py CHANGED
@@ -65,12 +65,22 @@ def load_model_and_tokenizer():
     return base, tok
     return base, tok
 
-model, tokenizer = load_model_and_tokenizer()
-DEVICE = model.device
+# Lazy load model and tokenizer
+_model = None
+_tokenizer = None
+
+def get_model_and_tokenizer():
+    global _model, _tokenizer
+    if _model is None:
+        _model, _tokenizer = load_model_and_tokenizer()
+    return _model, _tokenizer
 
 @spaces.GPU
 @torch.inference_mode()
 def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
+    model, tokenizer = get_model_and_tokenizer()
+    device = model.device
+
     # Apply Llama 3.1 chat template
     prompt = tokenizer.apply_chat_template(
         messages,
@@ -78,7 +88,7 @@ def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9)
         add_generation_prompt=True
     )
     enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-    enc = {k: v.to(DEVICE) for k, v in enc.items()}
+    enc = {k: v.to(device) for k, v in enc.items()}
 
     input_length = enc['input_ids'].shape[1]
     out = model.generate(
@@ -90,7 +100,8 @@ def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9)
         pad_token_id=tokenizer.eos_token_id,
     )
     # Decode only the newly generated tokens
-    return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
+    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
+    return generated_text.strip()
 
 def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
     # Build conversation in Llama 3.1 chat format
@@ -117,8 +128,7 @@ def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p)
     return reply
 
 demo = gr.ChatInterface(
-    fn=lambda message, history, system_prompt, max_new_tokens, temperature, top_p:
-        chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p),
+    fn=chat_fn,
     additional_inputs=[
         gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
         gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
@@ -130,4 +140,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
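
For reference, below is a minimal runnable sketch of the two patterns this commit adopts: a lazily initialized module-level cache so the model is only loaded on the first request, and passing chat_fn directly as fn= so gr.ChatInterface supplies message, history, and then each additional_inputs value in declaration order. The stubbed get_model(), the echo reply, and the temperature/top_p slider ranges are illustrative placeholders, not the Space's actual code.

# Hypothetical sketch, not the Space's app.py: lazy-loaded singleton plus
# gr.ChatInterface wired directly to the handler function.
import gradio as gr

_model = None  # module-level cache, populated on first use

def get_model():
    """Load the (stubbed) model once and reuse it on later calls."""
    global _model
    if _model is None:
        _model = object()  # stand-in for the real load_model_and_tokenizer()
    return _model

def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    # ChatInterface passes `message` and `history` first, then each
    # additional_inputs value in the order they are declared below,
    # so no lambda wrapper is needed to adapt the signature.
    get_model()  # heavy load happens lazily, inside the handler
    return f"echo: {message} (max_new_tokens={max_new_tokens})"

demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)  # share=True exposes a public Gradio link

In this sketch the model cache is only filled when the first chat message arrives, which keeps app startup fast; subsequent calls reuse the cached object instead of reloading it.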