lllouo committed on
Commit
31cbe45
·
1 Parent(s): e2221c7
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -154,11 +154,22 @@ Corrected:"""
154
  # 生成修正
155
  inputs = gec_tokenizer(prompt, return_tensors="pt").to(gec_model.device)
156
 
 
 
 
 
 
 
 
 
 
 
 
157
  with torch.no_grad():
158
  outputs = gec_model.generate(
159
  **inputs,
160
- max_new_tokens=512,
161
- num_beams=4,
162
  do_sample=False,
163
  temperature=None,
164
  top_p=None
@@ -847,6 +858,11 @@ with demo:
847
  )
848
 
849
  if __name__ == "__main__":
 
 
 
 
 
850
  demo.launch(
851
  server_name="0.0.0.0",
852
  server_port=7860,
 
154
  # 生成修正
155
  inputs = gec_tokenizer(prompt, return_tensors="pt").to(gec_model.device)
156
 
157
+ # 检测设备类型以优化参数
158
+ is_cpu = str(gec_model.device) == "cpu" or not torch.cuda.is_available()
159
+
160
+ # CPU优化参数:减少beam search和token长度
161
+ if is_cpu:
162
+ max_tokens = 256 # CPU模式减半
163
+ beams = 2 # 减少beam数量加速
164
+ else:
165
+ max_tokens = 512 # GPU模式保持
166
+ beams = 4
167
+
168
  with torch.no_grad():
169
  outputs = gec_model.generate(
170
  **inputs,
171
+ max_new_tokens=max_tokens,
172
+ num_beams=beams,
173
  do_sample=False,
174
  temperature=None,
175
  top_p=None
 
858
  )
859
 
860
  if __name__ == "__main__":
861
+ # Preload the WAC-GEC model before launching (this increases startup time).
862
+ # To skip preloading, comment out the two lines below.
863
+ print("🚀 预加载WAC-GEC模型...")
864
+ initialize_wac_gec()
865
+
866
  demo.launch(
867
  server_name="0.0.0.0",
868
  server_port=7860,