from __future__ import annotations

import logging
import os
import re
from dataclasses import dataclass
from typing import List, Optional

from fastapi import Body, FastAPI, Query
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field

# torch/transformers are optional: if they are missing, the service degrades to
# the extractive fallback instead of failing at import time.
try:
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception:  # pragma: no cover
    torch = None
    AutoModelForSeq2SeqLM = None
    AutoTokenizer = None


logger = logging.getLogger(__name__)


@dataclass
class SummaryOutput:
    summary: str
    backend: str
    used_target_length: Optional[int]
    error: Optional[str] = None


class SummarizationConfig:
    # Defaults for the seq2seq backend; model_name honors the MODEL_NAME env var.
    model_name: str = os.getenv("MODEL_NAME", "fnlp/bart-base-chinese")
    max_source_length: int = 512
    max_target_length: int = 160
    num_beams: int = 4
    no_repeat_ngram_size: int = 3
    length_penalty: float = 1.0
    fallback_sentences: int = 3
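

# Example: point the service at a different seq2seq model without code changes
# (hypothetical model id; assumes this module is saved as app.py):
#   MODEL_NAME=uer/bart-base-chinese-cluecorpussmall uvicorn app:app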


def normalize_text(text: str) -> str:
    # Collapse all whitespace (including ideographic spaces, U+3000) to single spaces.
    return " ".join(text.replace("\u3000", " ").split())


def split_sentences(text: str) -> List[str]:
    # Split after Chinese/ASCII sentence-ending punctuation, keeping the delimiter
    # attached to its sentence.
    parts = re.split(r"(?<=[。!?!?;;])\s*", text)
    return [p.strip() for p in parts if p.strip()]


def tokenize(text: str) -> List[str]:
    # Match runs of CJK ideographs or alphanumerics; lowercase for counting.
    return re.findall(r"[\u4e00-\u9fff]+|[A-Za-z0-9]+", text.lower())
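

# Quick examples of the two helpers:
#   split_sentences("天气很好。出去走走!")  -> ["天气很好。", "出去走走!"]
#   tokenize("GPT-4 模型")                   -> ["gpt", "4", "模型"]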


class SimpleExtractiveSummarizer:
    def __init__(self, max_sentences: int = 3):
        self.max_sentences = max_sentences

    def summarize(self, text: str, target_length: int | None = None) -> str:
        sentences = split_sentences(text)
        if not sentences:
            return ""
        if len(sentences) == 1:
            return sentences[0]

        # Count token frequencies over the whole document.
        freq = {}
        for sentence in sentences:
            for token in tokenize(sentence):
                freq[token] = freq.get(token, 0) + 1

        # Score each sentence by its mean token frequency, keeping its index so
        # ties break toward earlier sentences.
        scored = []
        for idx, sentence in enumerate(sentences):
            tokens = tokenize(sentence)
            score = sum(freq.get(token, 0) for token in tokens) / max(1, len(tokens))
            scored.append((score, idx, sentence))

        # Take the top-scoring sentences, then restore document order.
        scored.sort(key=lambda item: (-item[0], item[1]))
        selected = sorted(scored[: self.max_sentences], key=lambda item: item[1])
        kept: List[str] = []
        total = 0
        # Keep sentences until the target length would be exceeded, but always
        # return at least one sentence.
        for _, _, sentence in selected:
            if target_length is not None and kept and total + len(sentence) > target_length:
                break
            kept.append(sentence)
            total += len(sentence)
        return "".join(kept or [selected[0][2]])


class HybridSummarizer:
    def __init__(self, model_name: str | None = None):
        self.model_name = os.getenv("MODEL_NAME", model_name or SummarizationConfig.model_name)
        self.backend_name = "fallback"
        self.tokenizer = None
        self.model = None
        self.fallback = SimpleExtractiveSummarizer(max_sentences=SummarizationConfig.fallback_sentences)
        self.device = "cpu"
        self.load_error: str | None = None
        self._try_load_transformer()

    def _try_load_transformer(self) -> None:
        if AutoTokenizer is None or AutoModelForSeq2SeqLM is None or torch is None:
            self.load_error = "torch/transformers not installed"
            return
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()
            self.backend_name = "transformer"
            self.load_error = None
        except Exception as exc:
            self.load_error = f"{type(exc).__name__}: {exc}"
            logger.exception("Failed to load transformer model: %s", self.model_name)
            self.tokenizer = None
            self.model = None
            self.backend_name = "fallback"

    def summarize(self, text: str, target_length: int | None = None) -> SummaryOutput:
        text = normalize_text(text)
        if not text:
            return SummaryOutput(summary="", backend=self.backend_name, used_target_length=target_length)
        if self.backend_name == "transformer" and self.tokenizer and self.model:
            try:
                return SummaryOutput(
                    summary=self._summarize_with_transformer(text, target_length),
                    backend="transformer",
                    used_target_length=target_length,
                )
            except Exception as exc:
                logger.exception("Transformer generation failed")
                return SummaryOutput(
                    summary=self.fallback.summarize(text, target_length=target_length),
                    backend="fallback",
                    used_target_length=target_length,
                    error=f"{type(exc).__name__}: {exc}",
                )
        return SummaryOutput(
            summary=self.fallback.summarize(text, target_length=target_length),
            backend="fallback",
            used_target_length=target_length,
        )

    def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=SummarizationConfig.max_source_length,
        )
        # BERT-style tokenizers (e.g. for fnlp/bart-base-chinese) emit token_type_ids,
        # which BART's generate() does not accept.
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Budget roughly one generated token per requested character, within a sane range.
        max_new_tokens = max(48, min(192, int((target_length or 120) * 1.1)))
        with torch.no_grad():
            # Read generation settings from the shared config so they live in one place.
            generated = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=SummarizationConfig.num_beams,
                no_repeat_ngram_size=SummarizationConfig.no_repeat_ngram_size,
                length_penalty=SummarizationConfig.length_penalty,
                early_stopping=True,
            )
        return self.tokenizer.decode(
            generated[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()


app = FastAPI(title="Transformer Summarizer Demo", version="1.0.0")
# Load the model once at import time: startup may be slow while weights download,
# but individual requests stay fast.
engine = HybridSummarizer()


class SummarizeRequest(BaseModel):
    text: str
    target_length: int | None = Field(default=120, ge=1, description="目标摘要长度")


class SummarizeResponse(BaseModel):
    summary: str
    backend: str
    target_length: int | None
    error: str | None = None


@app.get("/health")
def health():
    return {
        "status": "ok",
        "backend": engine.backend_name,
        "model_name": engine.model_name,
        "load_error": engine.load_error,
    }


@app.post("/summarize", response_model=SummarizeResponse)
def summarize(req: SummarizeRequest):
    result = engine.summarize(req.text, target_length=req.target_length)
    return SummarizeResponse(
        summary=result.summary,
        backend=result.backend,
        target_length=result.used_target_length,
        error=result.error,
    )


@app.post("/summarize-plain", response_model=SummarizeResponse)
def summarize_plain(
    text: str = Body(..., media_type="text/plain", description="直接粘贴原文,支持换行和空格"),
    target_length: int = Query(120, ge=1, description="目标摘要长度"),
):
    result = engine.summarize(text, target_length=target_length)
    return SummarizeResponse(
        summary=result.summary,
        backend=result.backend,
        target_length=result.used_target_length,
        error=result.error,
    )
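

# Example requests, assuming the server listens on localhost:8000:
#   curl -X POST localhost:8000/summarize \
#     -H 'Content-Type: application/json' \
#     -d '{"text": "一段较长的中文文本……", "target_length": 120}'
#   curl -X POST 'localhost:8000/summarize-plain?target_length=120' \
#     -H 'Content-Type: text/plain' --data-binary '原文内容……'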


@app.get("/")
def root():
    error_note = f"<p>模型加载错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
    # The page below contains literal CSS braces, so dynamic values are spliced in
    # by string concatenation rather than an f-string.
    html = """
    <!DOCTYPE html>
    <html lang="zh-CN">
    <head>
      <meta charset="utf-8" />
      <meta name="viewport" content="width=device-width, initial-scale=1" />
      <title>Transformer Summarizer Demo</title>
      <style>
        body {
          margin: 0;
          font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
          background: linear-gradient(135deg, #f7f8fc 0%, #eef4ff 100%);
          color: #1f2937;
        }
        .wrap {
          max-width: 920px;
          margin: 0 auto;
          padding: 56px 20px 72px;
        }
        .card {
          background: rgba(255, 255, 255, 0.92);
          border: 1px solid rgba(148, 163, 184, 0.25);
          border-radius: 20px;
          padding: 32px;
          box-shadow: 0 20px 60px rgba(15, 23, 42, 0.08);
          backdrop-filter: blur(8px);
        }
        h1 {
          margin: 0 0 12px;
          font-size: 34px;
        }
        h2 {
          margin: 24px 0 10px;
          font-size: 22px;
        }
        p {
          line-height: 1.75;
          margin: 10px 0;
        }
        .btns {
          display: flex;
          flex-wrap: wrap;
          gap: 14px;
          margin: 28px 0 18px;
        }
        a.btn {
          display: inline-block;
          padding: 14px 22px;
          border-radius: 12px;
          text-decoration: none;
          font-weight: 600;
          transition: transform 0.15s ease, box-shadow 0.15s ease;
        }
        a.btn:hover {
          transform: translateY(-1px);
        }
        .primary {
          background: #2563eb;
          color: white;
          box-shadow: 0 10px 20px rgba(37, 99, 235, 0.22);
        }
        .secondary {
          background: white;
          color: #2563eb;
          border: 1px solid rgba(37, 99, 235, 0.2);
        }
        .guide {
          margin-top: 26px;
          padding-top: 18px;
          border-top: 1px solid rgba(148, 163, 184, 0.25);
        }
        code {
          background: #eef2ff;
          padding: 2px 6px;
          border-radius: 6px;
        }
        pre {
          background: #f8fafc;
          color: #111827;
          padding: 16px;
          border-radius: 12px;
          overflow-x: auto;
          border: 1px solid rgba(148, 163, 184, 0.25);
        }
        pre code {
          background: transparent;
          padding: 0;
          border-radius: 0;
          color: inherit;
        }
        .meta {
          color: #6b7280;
          font-size: 14px;
          margin-top: 14px;
        }
      </style>
    </head>
    <body>
      <div class="wrap">
        <div class="card">
          <h1>Transformer Summarizer Demo</h1>
          <p>这是一个基于 Transformer 的中文文本摘要演示系统。你可以通过下面两个按钮进入接口文档或检查服务状态,也可以直接调用摘要接口。</p>
          <p>当前模型:<code>""" + engine.model_name + """</code></p>
          <p>当前后端:<code>""" + engine.backend_name + """</code></p>
          """ + error_note + """

          <div class="btns">
            <a class="btn primary" href="/docs" target="_blank" rel="noreferrer">打开接口文档</a>
            <a class="btn secondary" href="/health" target="_blank" rel="noreferrer">检查服务状态</a>
          </div>

          <div class="guide">
            <h2>使用指南</h2>
            <p>1. 点击 <code>打开接口文档</code>,进入 Swagger 页面。</p>
            <p>2. 找到 <code>POST /summarize</code>,点击 <code>Try it out</code>。</p>
            <p>3. 在请求体中填写文本和目标长度,例如:</p>
            <pre><code>{
  "text": "这里放一段较长的中文文本",
  "target_length": 120
}</code></pre>
            <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
            <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
            <p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
            <p>7. 如果原文包含大量换行或空格,建议直接使用 <code>POST /summarize-plain</code>,把正文当作纯文本提交,更适合粘贴文章正文。</p>
            <div class="meta">
              提示:<code>/summarize</code> 走 JSON,<code>/summarize-plain</code> 走纯文本。前者适合结构化调用,后者适合直接粘贴文章。
            </div>
          </div>
        </div>
      </div>
    </body>
    </html>
    """
    return HTMLResponse(content=html)
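

# A minimal sketch for running the service locally; assumes uvicorn is installed
# (it is an extra dependency and deliberately not imported at module top level).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)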