Wenye He committed on
Commit
3a4c40c
·
verified ·
1 Parent(s): f937954

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -56,18 +56,31 @@ class ChatModel:
56
  # Format prompt
57
  prompt = config["template"].format(message=message)
58
 
59
- # Tokenize input
60
- inputs = self.tokenizers[model_name](prompt, return_tensors="pt").to("cuda")
 
 
 
 
 
61
 
62
- # Generate response
63
- outputs = self.models[model_name].generate(
64
- **inputs,
65
- max_new_tokens=384,
66
- temperature=0.7,
67
- top_p=0.9,
68
- do_sample=True,
69
- pad_token_id=self.tokenizers[model_name].eos_token_id
70
- )
 
 
 
 
 
 
 
 
71
 
72
  # Decode response
73
  response = self.tokenizers[model_name].decode(
@@ -77,7 +90,7 @@ class ChatModel:
77
 
78
  # Calculate metrics
79
  elapsed_time = time.time() - start_time
80
- tokens = outputs[0].shape[0] - inputs.input_ids.shape[-1]
81
  tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
82
 
83
  return response, elapsed_time, tokens_per_sec
 
56
  # Format prompt
57
  prompt = config["template"].format(message=message)
58
 
59
+ # Tokenize input with proper max_length handling
60
+ inputs = self.tokenizers[model_name](
61
+ prompt,
62
+ return_tensors="pt",
63
+ max_length=2048,
64
+ truncation=True
65
+ ).to("cuda")
66
 
67
+ # Generation parameters
68
+ generation_kwargs = {
69
+ "inputs": inputs.input_ids,
70
+ "max_new_tokens": 384,
71
+ "temperature": 0.7,
72
+ "top_p": 0.9,
73
+ "do_sample": True,
74
+ "pad_token_id": self.tokenizers[model_name].eos_token_id
75
+ }
76
+
77
+ # Phi-3 specific workaround
78
+ if "phi-3" in model_name:
79
+ generation_kwargs["attention_mask"] = inputs.attention_mask
80
+ generation_kwargs.pop("inputs")
81
+ generation_kwargs["input_ids"] = inputs.input_ids
82
+
83
+ outputs = self.models[model_name].generate(**generation_kwargs)
84
 
85
  # Decode response
86
  response = self.tokenizers[model_name].decode(
 
90
 
91
  # Calculate metrics
92
  elapsed_time = time.time() - start_time
93
+ tokens = outputs[0].shape[-1] - inputs.input_ids.shape[-1]
94
  tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
95
 
96
  return response, elapsed_time, tokens_per_sec