Spaces:
Paused
Paused
Wenye He
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -56,18 +56,31 @@ class ChatModel:
|
|
| 56 |
# Format prompt
|
| 57 |
prompt = config["template"].format(message=message)
|
| 58 |
|
| 59 |
-
# Tokenize input
|
| 60 |
-
inputs = self.tokenizers[model_name](
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
max_new_tokens
|
| 66 |
-
temperature
|
| 67 |
-
top_p
|
| 68 |
-
do_sample
|
| 69 |
-
pad_token_id
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Decode response
|
| 73 |
response = self.tokenizers[model_name].decode(
|
|
@@ -77,7 +90,7 @@ class ChatModel:
|
|
| 77 |
|
| 78 |
# Calculate metrics
|
| 79 |
elapsed_time = time.time() - start_time
|
| 80 |
-
tokens = outputs[0].shape[
|
| 81 |
tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
|
| 82 |
|
| 83 |
return response, elapsed_time, tokens_per_sec
|
|
|
|
| 56 |
# Format prompt
|
| 57 |
prompt = config["template"].format(message=message)
|
| 58 |
|
| 59 |
+
# Tokenize input with proper max_length handling
|
| 60 |
+
inputs = self.tokenizers[model_name](
|
| 61 |
+
prompt,
|
| 62 |
+
return_tensors="pt",
|
| 63 |
+
max_length=2048,
|
| 64 |
+
truncation=True
|
| 65 |
+
).to("cuda")
|
| 66 |
|
| 67 |
+
# Generation parameters
|
| 68 |
+
generation_kwargs = {
|
| 69 |
+
"inputs": inputs.input_ids,
|
| 70 |
+
"max_new_tokens": 384,
|
| 71 |
+
"temperature": 0.7,
|
| 72 |
+
"top_p": 0.9,
|
| 73 |
+
"do_sample": True,
|
| 74 |
+
"pad_token_id": self.tokenizers[model_name].eos_token_id
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
# Phi-3 specific workaround
|
| 78 |
+
if "phi-3" in model_name:
|
| 79 |
+
generation_kwargs["attention_mask"] = inputs.attention_mask
|
| 80 |
+
generation_kwargs.pop("inputs")
|
| 81 |
+
generation_kwargs["input_ids"] = inputs.input_ids
|
| 82 |
+
|
| 83 |
+
outputs = self.models[model_name].generate(**generation_kwargs)
|
| 84 |
|
| 85 |
# Decode response
|
| 86 |
response = self.tokenizers[model_name].decode(
|
|
|
|
| 90 |
|
| 91 |
# Calculate metrics
|
| 92 |
elapsed_time = time.time() - start_time
|
| 93 |
+
tokens = outputs[0].shape[-1] - inputs.input_ids.shape[-1]
|
| 94 |
tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
|
| 95 |
|
| 96 |
return response, elapsed_time, tokens_per_sec
|