czjun commited on
Commit
9704503
·
1 Parent(s): f5a5d88

feat: 更新模型配置和错误处理,添加protobuf依赖

Browse files
Files changed (2) hide show
  1. app.py +24 -6
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,5 +1,7 @@
1
  from __future__ import annotations
2
 
 
 
3
  from dataclasses import dataclass
4
  from typing import List, Optional
5
 
@@ -16,6 +18,9 @@ except Exception: # pragma: no cover
16
  AutoTokenizer = None
17
 
18
 
 
 
 
19
  @dataclass
20
  class SummaryOutput:
21
  summary: str
@@ -24,7 +29,7 @@ class SummaryOutput:
24
 
25
 
26
  class SummarizationConfig:
27
- model_name: str = "google/mt5-small"
28
  max_source_length: int = 1024
29
  max_target_length: int = 160
30
  num_beams: int = 4
@@ -85,17 +90,19 @@ class SimpleExtractiveSummarizer:
85
 
86
 
87
  class HybridSummarizer:
88
- def __init__(self, model_name: str = "google/mt5-small"):
89
- self.model_name = model_name
90
  self.backend_name = "fallback"
91
  self.tokenizer = None
92
  self.model = None
93
  self.fallback = SimpleExtractiveSummarizer()
94
  self.device = "cpu"
 
95
  self._try_load_transformer()
96
 
97
  def _try_load_transformer(self) -> None:
98
  if AutoTokenizer is None or AutoModelForSeq2SeqLM is None or torch is None:
 
99
  return
100
  try:
101
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
@@ -103,7 +110,10 @@ class HybridSummarizer:
103
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
104
  self.model.to(self.device)
105
  self.backend_name = "transformer"
106
- except Exception:
 
 
 
107
  self.tokenizer = None
108
  self.model = None
109
  self.backend_name = "fallback"
@@ -128,7 +138,7 @@ class HybridSummarizer:
128
  )
129
 
130
  def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
131
- prompt = f"请根据目标长度 {target_length or 120} 字生成摘要:{text}"
132
  inputs = self.tokenizer(
133
  prompt,
134
  return_tensors="pt",
@@ -167,7 +177,12 @@ class SummarizeResponse(BaseModel):
167
 
168
  @app.get("/health")
169
  def health():
170
- return {"status": "ok", "backend": engine.backend_name}
 
 
 
 
 
171
 
172
 
173
  @app.post("/summarize", response_model=SummarizeResponse)
@@ -277,6 +292,8 @@ def root():
277
  <div class="card">
278
  <h1>Transformer Summarizer Demo</h1>
279
  <p>这是一个基于 Transformer 的中文文本摘要演示系统。你可以通过下面两个按钮进入接口文档或检查服务状态,也可以直接调用摘要接口。</p>
 
 
280
 
281
  <div class="btns">
282
  <a class="btn primary" href="/docs" target="_blank" rel="noreferrer">打开接口文档</a>
@@ -294,6 +311,7 @@ def root():
294
  }</code></pre>
295
  <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
296
  <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
 
297
  <div class="meta">
298
  提示:如果文本里有换行,请确保是合法 JSON。建议直接在 Swagger 页面提交,避免手写 JSON 出错。
299
  </div>
 
1
  from __future__ import annotations
2
 
3
+ import logging
4
+ import os
5
  from dataclasses import dataclass
6
  from typing import List, Optional
7
 
 
18
  AutoTokenizer = None
19
 
20
 
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
  @dataclass
25
  class SummaryOutput:
26
  summary: str
 
29
 
30
 
31
  class SummarizationConfig:
32
+ model_name: str = os.getenv("MODEL_NAME", "IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese")
33
  max_source_length: int = 1024
34
  max_target_length: int = 160
35
  num_beams: int = 4
 
90
 
91
 
92
  class HybridSummarizer:
93
+ def __init__(self, model_name: str | None = None):
94
+ self.model_name = os.getenv("MODEL_NAME", model_name or SummarizationConfig.model_name)
95
  self.backend_name = "fallback"
96
  self.tokenizer = None
97
  self.model = None
98
  self.fallback = SimpleExtractiveSummarizer()
99
  self.device = "cpu"
100
+ self.load_error: str | None = None
101
  self._try_load_transformer()
102
 
103
  def _try_load_transformer(self) -> None:
104
  if AutoTokenizer is None or AutoModelForSeq2SeqLM is None or torch is None:
105
+ self.load_error = "torch/transformers not installed"
106
  return
107
  try:
108
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
110
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
111
  self.model.to(self.device)
112
  self.backend_name = "transformer"
113
+ self.load_error = None
114
+ except Exception as exc:
115
+ self.load_error = f"{type(exc).__name__}: {exc}"
116
+ logger.exception("Failed to load transformer model: %s", self.model_name)
117
  self.tokenizer = None
118
  self.model = None
119
  self.backend_name = "fallback"
 
138
  )
139
 
140
  def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
141
+ prompt = f"summarize: {text}"
142
  inputs = self.tokenizer(
143
  prompt,
144
  return_tensors="pt",
 
177
 
178
  @app.get("/health")
179
  def health():
180
+ return {
181
+ "status": "ok",
182
+ "backend": engine.backend_name,
183
+ "model_name": engine.model_name,
184
+ "load_error": engine.load_error,
185
+ }
186
 
187
 
188
  @app.post("/summarize", response_model=SummarizeResponse)
 
292
  <div class="card">
293
  <h1>Transformer Summarizer Demo</h1>
294
  <p>这是一个基于 Transformer 的中文文本摘要演示系统。你可以通过下面两个按钮进入接口文档或检查服务状态,也可以直接调用摘要接口。</p>
295
+ <p>当前模型:<code>{engine.model_name}</code></p>
296
+ <p>当前后端:<code>{engine.backend_name}</code></p>
297
 
298
  <div class="btns">
299
  <a class="btn primary" href="/docs" target="_blank" rel="noreferrer">打开接口文档</a>
 
311
  }</code></pre>
312
  <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
313
  <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
314
+ <p>6. 如果健康检查里的 <code>backend</code> 仍然是 <code>fallback</code>,说明 Transformer 模型没有成功加载,请先查看 <code>load_error</code> 的原因。</p>
315
  <div class="meta">
316
  提示:如果文本里有换行,请确保是合法 JSON。建议直接在 Swagger 页面提交,避免手写 JSON 出错。
317
  </div>
requirements.txt CHANGED
@@ -4,4 +4,4 @@ pydantic>=2.7.0
4
  transformers>=4.41.0
5
  sentencepiece>=0.2.0
6
  torch>=2.1.0
7
-
 
4
  transformers>=4.41.0
5
  sentencepiece>=0.2.0
6
  torch>=2.1.0
7
+ protobuf>=4.25.0