czjun committed on
Commit
897c2d5
·
1 Parent(s): 9704503

feat: 添加错误处理和模型评估模式(eval),优化摘要生成逻辑

Browse files
Files changed (3) hide show
  1. README.md +4 -0
  2. __pycache__/app.cpython-310.pyc +0 -0
  3. app.py +34 -19
README.md CHANGED
@@ -9,3 +9,7 @@ license: mit
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
+
13
+ To force a specific transformer model in Spaces, set the `MODEL_NAME` environment variable, for example:
14
+
15
+ `IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese`
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
app.py CHANGED
@@ -26,6 +26,7 @@ class SummaryOutput:
26
  summary: str
27
  backend: str
28
  used_target_length: Optional[int]
 
29
 
30
 
31
  class SummarizationConfig:
@@ -109,6 +110,7 @@ class HybridSummarizer:
109
  self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
110
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
111
  self.model.to(self.device)
 
112
  self.backend_name = "transformer"
113
  self.load_error = None
114
  except Exception as exc:
@@ -129,8 +131,14 @@ class HybridSummarizer:
129
  backend="transformer",
130
  used_target_length=target_length,
131
  )
132
- except Exception:
133
- pass
 
 
 
 
 
 
134
  return SummaryOutput(
135
  summary=self.fallback.summarize(text, target_length=target_length),
136
  backend="fallback",
@@ -138,26 +146,29 @@ class HybridSummarizer:
138
  )
139
 
140
  def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
141
- prompt = f"summarize: {text}"
142
  inputs = self.tokenizer(
143
  prompt,
144
  return_tensors="pt",
145
  truncation=True,
146
- max_length=SummarizationConfig.max_source_length,
147
  )
148
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
149
- max_new_tokens = max(32, min(256, int((target_length or 120) * 1.2)))
150
- min_new_tokens = max(16, int(max_new_tokens * 0.4))
151
- generated = self.model.generate(
152
- **inputs,
153
- max_new_tokens=max_new_tokens,
154
- min_new_tokens=min_new_tokens,
155
- num_beams=SummarizationConfig.num_beams,
156
- no_repeat_ngram_size=SummarizationConfig.no_repeat_ngram_size,
157
- length_penalty=SummarizationConfig.length_penalty,
158
- early_stopping=True,
159
- )
160
- return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
 
 
 
161
 
162
 
163
  app = FastAPI(title="Transformer Summarizer Demo", version="1.0.0")
@@ -173,6 +184,7 @@ class SummarizeResponse(BaseModel):
173
  summary: str
174
  backend: str
175
  target_length: int | None
 
176
 
177
 
178
  @app.get("/health")
@@ -192,11 +204,13 @@ def summarize(req: SummarizeRequest):
192
  summary=result.summary,
193
  backend=result.backend,
194
  target_length=result.used_target_length,
 
195
  )
196
 
197
 
198
  @app.get("/")
199
  def root():
 
200
  html = """
201
  <!DOCTYPE html>
202
  <html lang="zh-CN">
@@ -294,6 +308,7 @@ def root():
294
  <p>这是一个基于 Transformer 的中文文本摘要演示系统。你可以通过下面两个按钮进入接口文档或检查服务状态,也可以直接调用摘要接口。</p>
295
  <p>当前模型:<code>{engine.model_name}</code></p>
296
  <p>当前后端:<code>{engine.backend_name}</code></p>
 
297
 
298
  <div class="btns">
299
  <a class="btn primary" href="/docs" target="_blank" rel="noreferrer">打开接口文档</a>
@@ -309,9 +324,9 @@ def root():
309
  "text": "这里放一段较长的中文文本",
310
  "target_length": 120
311
  }</code></pre>
312
- <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
313
- <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
314
- <p>6. 如果健康检查里的 <code>backend</code> 仍然是 <code>fallback</code>,说明 Transformer 模型没有成功加载,查看 <code>load_error</code> 的原因。</p>
315
  <div class="meta">
316
  提示:如果文本里有换行,请确保是合法 JSON。建议直接在 Swagger 页面提交,避免手写 JSON 出错。
317
  </div>
 
26
  summary: str
27
  backend: str
28
  used_target_length: Optional[int]
29
+ error: Optional[str] = None
30
 
31
 
32
  class SummarizationConfig:
 
110
  self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
111
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
112
  self.model.to(self.device)
113
+ self.model.eval()
114
  self.backend_name = "transformer"
115
  self.load_error = None
116
  except Exception as exc:
 
131
  backend="transformer",
132
  used_target_length=target_length,
133
  )
134
+ except Exception as exc:
135
+ logger.exception("Transformer generation failed")
136
+ return SummaryOutput(
137
+ summary=self.fallback.summarize(text, target_length=target_length),
138
+ backend="fallback",
139
+ used_target_length=target_length,
140
+ error=f"{type(exc).__name__}: {exc}",
141
+ )
142
  return SummaryOutput(
143
  summary=self.fallback.summarize(text, target_length=target_length),
144
  backend="fallback",
 
146
  )
147
 
148
  def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
149
+ prompt = text
150
  inputs = self.tokenizer(
151
  prompt,
152
  return_tensors="pt",
153
  truncation=True,
154
+ max_length=512,
155
  )
156
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
157
+ max_new_tokens = max(48, min(192, int((target_length or 120) * 1.1)))
158
+ with torch.no_grad():
159
+ generated = self.model.generate(
160
+ **inputs,
161
+ max_new_tokens=max_new_tokens,
162
+ num_beams=2,
163
+ no_repeat_ngram_size=3,
164
+ length_penalty=1.0,
165
+ early_stopping=True,
166
+ )
167
+ return self.tokenizer.decode(
168
+ generated[0],
169
+ skip_special_tokens=True,
170
+ clean_up_tokenization_spaces=True,
171
+ ).strip()
172
 
173
 
174
  app = FastAPI(title="Transformer Summarizer Demo", version="1.0.0")
 
184
  summary: str
185
  backend: str
186
  target_length: int | None
187
+ error: str | None = None
188
 
189
 
190
  @app.get("/health")
 
204
  summary=result.summary,
205
  backend=result.backend,
206
  target_length=result.used_target_length,
207
+ error=result.error,
208
  )
209
 
210
 
211
  @app.get("/")
212
  def root():
213
+ error_note = f"<p>最近一次生成错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
214
  html = """
215
  <!DOCTYPE html>
216
  <html lang="zh-CN">
 
308
  <p>这是一个基于 Transformer 的中文文本摘要演示系统。你可以通过下面两个按钮进入接口文档或检查服务状态,也可以直接调用摘要接口。</p>
309
  <p>当前模型:<code>{engine.model_name}</code></p>
310
  <p>当前后端:<code>{engine.backend_name}</code></p>
311
+ """ + error_note + """
312
 
313
  <div class="btns">
314
  <a class="btn primary" href="/docs" target="_blank" rel="noreferrer">打开接口文档</a>
 
324
  "text": "这里放一段较长的中文文本",
325
  "target_length": 120
326
  }</code></pre>
327
+ <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
328
+ <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
329
+ <p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
330
  <div class="meta">
331
  提示:如果文本里有换行,请确保是合法 JSON。建议直接在 Swagger 页面提交,避免手写 JSON 出错。
332
  </div>