czjun committed
Commit 897c2d5 · Parent(s): 9704503

feat: add error handling and model eval mode, optimize summary generation logic

Files changed:
- README.md +4 -0
- __pycache__/app.cpython-310.pyc +0 -0
- app.py +34 -19
README.md CHANGED
@@ -9,3 +9,7 @@ license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+To force a specific transformer model in Spaces, set the `MODEL_NAME` environment variable, for example:
+
+`IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese`
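For reference, a minimal sketch of how a Space can pick that checkpoint up at startup; the lookup and its default shown here are an assumption for illustration, since the corresponding app.py code is not part of this diff:

    import os

    # Assumed lookup: fall back to a default checkpoint when MODEL_NAME is unset.
    MODEL_NAME = os.environ.get(
        "MODEL_NAME",
        "IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese",  # example from the README
    )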
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
app.py CHANGED
@@ -26,6 +26,7 @@ class SummaryOutput:
     summary: str
     backend: str
     used_target_length: Optional[int]
+    error: Optional[str] = None
 
 
 class SummarizationConfig:
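The new error field records why generation failed; because it defaults to None, call sites that omit it keep working. A self-contained sketch of the shape (the @dataclass decorator is assumed here, as it is not visible in this diff):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SummaryOutput:
        summary: str
        backend: str
        used_target_length: Optional[int]
        error: Optional[str] = None  # set only when the transformer path fails

    # Call sites that predate the field remain valid:
    ok = SummaryOutput(summary="...", backend="transformer", used_target_length=120)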
@@ -109,6 +110,7 @@ class HybridSummarizer:
             self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
             self.model.to(self.device)
+            self.model.eval()
             self.backend_name = "transformer"
             self.load_error = None
         except Exception as exc:
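The added model.eval() switches the network into inference mode, disabling dropout and other train-time behavior. A standalone sketch of the same load sequence, using the checkpoint from the README as the example:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_name = "IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()  # inference mode: dropout off, deterministic forward passes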
@@ -129,8 +131,14 @@ class HybridSummarizer:
                 backend="transformer",
                 used_target_length=target_length,
             )
-        except Exception:
-
+        except Exception as exc:
+            logger.exception("Transformer generation failed")
+            return SummaryOutput(
+                summary=self.fallback.summarize(text, target_length=target_length),
+                backend="fallback",
+                used_target_length=target_length,
+                error=f"{type(exc).__name__}: {exc}",
+            )
         return SummaryOutput(
             summary=self.fallback.summarize(text, target_length=target_length),
             backend="fallback",
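The error string is the exception's type name plus its message, which keeps the fallback reason short and readable in API responses. A quick illustration of the format (the sample exception is hypothetical):

    try:
        raise RuntimeError("CUDA out of memory")  # hypothetical failure
    except Exception as exc:
        detail = f"{type(exc).__name__}: {exc}"

    print(detail)  # -> RuntimeError: CUDA out of memory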
@@ -138,26 +146,29 @@ class HybridSummarizer:
         )
 
     def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
-        prompt =
+        prompt = text
         inputs = self.tokenizer(
             prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=
+            max_length=512,
         )
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-        max_new_tokens = max(
+        max_new_tokens = max(48, min(192, int((target_length or 120) * 1.1)))
+        with torch.no_grad():
+            generated = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                num_beams=2,
+                no_repeat_ngram_size=3,
+                length_penalty=1.0,
+                early_stopping=True,
+            )
+        return self.tokenizer.decode(
+            generated[0],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
+        ).strip()
 
 
 app = FastAPI(title="Transformer Summarizer Demo", version="1.0.0")
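The new token budget clamps 1.1x the requested length into [48, 192], so extreme target_length values still produce a sane generation limit. A worked example of the arithmetic:

    def token_budget(target_length: int | None) -> int:
        # 1.1x the requested length, clamped to [48, 192]; the default request is 120
        return max(48, min(192, int((target_length or 120) * 1.1)))

    assert token_budget(None) == 132  # 120 * 1.1
    assert token_budget(20) == 48     # raised to the floor
    assert token_budget(400) == 192   # capped at the ceiling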
@@ -173,6 +184,7 @@ class SummarizeResponse(BaseModel):
     summary: str
     backend: str
     target_length: int | None
+    error: str | None = None
 
 
 @app.get("/health")
@@ -192,11 +204,13 @@ def summarize(req: SummarizeRequest):
         summary=result.summary,
         backend=result.backend,
         target_length=result.used_target_length,
+        error=result.error,
     )
 
 
 @app.get("/")
 def root():
+    error_note = f"<p>Most recent generation error: <code>{engine.load_error}</code></p>" if engine.load_error else ""
     html = """
     <!DOCTYPE html>
     <html lang="zh-CN">
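With the response model extended, a client can tell a transformer summary from a fallback one. A minimal sketch using requests; the URL and port are placeholders for wherever the Space is served:

    import requests

    resp = requests.post(
        "http://localhost:7860/summarize",  # placeholder URL
        json={"text": "这里放一段较长的中文文本", "target_length": 120},
        timeout=60,
    )
    data = resp.json()
    if data["backend"] == "fallback":
        print("generation failed:", data["error"])  # the new field explains why
    else:
        print(data["summary"])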
@@ -294,6 +308,7 @@ def root():
             <p>This is a Transformer-based Chinese text summarization demo. Use the two buttons below to open the API docs or check the service status, or call the summarization endpoint directly.</p>
             <p>Current model: <code>{engine.model_name}</code></p>
             <p>Current backend: <code>{engine.backend_name}</code></p>
+            """ + error_note + """
 
             <div class="btns">
                 <a class="btn primary" href="/docs" target="_blank" rel="noreferrer">Open API docs</a>
@@ -309,9 +324,9 @@ def root():
                 "text": "这里放一段较长的中文文本",
                 "target_length": 120
             }</code></pre>
-
-
-
+            <p>4. Click <code>Execute</code> and read the returned summary.</p>
+            <p>5. To check that the service is up, click <code>Check service status</code>; a response of <code>ok</code> means it is running normally.</p>
+            <p>6. If the API returns <code>backend=fallback</code>, look at the <code>error</code> field in the response: it means the Transformer generation step failed and the system automatically fell back to the backup summarizer.</p>
             <div class="meta">
                 Tip: if the text contains line breaks, make sure the request body is still valid JSON. Submitting from the Swagger page is recommended, to avoid mistakes in hand-written JSON.
             </div>
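Step 5 can likewise be scripted; the exact payload of /health is not shown in this diff, so the sketch below only assumes the route exists and prints whatever it returns:

    import requests

    r = requests.get("http://localhost:7860/health", timeout=10)  # placeholder URL
    print(r.status_code, r.json())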