Jn-Huang committed
Commit · 4cc1531 · 1 Parent(s): fc3b3a2

Fix Gradio ChatInterface: remove lambda wrapper, add lazy loading, make public
app.py CHANGED
@@ -65,12 +65,22 @@ def load_model_and_tokenizer():
         return base, tok
     return base, tok
 
-model
-
+# Lazy load model and tokenizer
+_model = None
+_tokenizer = None
+
+def get_model_and_tokenizer():
+    global _model, _tokenizer
+    if _model is None:
+        _model, _tokenizer = load_model_and_tokenizer()
+    return _model, _tokenizer
 
 @spaces.GPU
 @torch.inference_mode()
 def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
+    model, tokenizer = get_model_and_tokenizer()
+    device = model.device
+
     # Apply Llama 3.1 chat template
     prompt = tokenizer.apply_chat_template(
         messages,
@@ -78,7 +88,7 @@ def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9)
         add_generation_prompt=True
     )
     enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-    enc = {k: v.to(
+    enc = {k: v.to(device) for k, v in enc.items()}
 
     input_length = enc['input_ids'].shape[1]
     out = model.generate(
@@ -90,7 +100,8 @@
         pad_token_id=tokenizer.eos_token_id,
     )
     # Decode only the newly generated tokens
-
+    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
+    return generated_text.strip()
 
 def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
     # Build conversation in Llama 3.1 chat format
@@ -117,8 +128,7 @@ def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p)
     return reply
 
 demo = gr.ChatInterface(
-    fn=
-        chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p),
+    fn=chat_fn,
     additional_inputs=[
         gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
         gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
@@ -130,4 +140,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
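
For context on why fn=chat_fn works without the removed lambda wrapper: gr.ChatInterface calls fn with the new message and the chat history first, then the current value of each component listed in additional_inputs, in that order. The sketch below is a minimal illustration of that wiring with a stub reply function; the temperature and top_p sliders are assumptions (those lines fall outside the diff), and the stub body stands in for the real generate_response() call.

import gradio as gr

def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    # ChatInterface supplies (message, history, *additional_inputs) positionally,
    # so passing the function object directly is enough -- no lambda wrapper needed.
    # Stub reply; the real app builds a Llama 3.1 conversation and calls generate_response().
    return f"(echo) {message} [max_new_tokens={max_new_tokens}, temperature={temperature}]"

demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature"),  # assumed range/default
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="top_p"),        # assumed range/default
    ],
)

if __name__ == "__main__":
    demo.launch()

The lazy get_model_and_tokenizer() helper added in the same commit keeps the heavyweight load_model_and_tokenizer() call out of import time, so the Space can build the UI before any model weights are loaded.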