hotchpotch commited on
Commit
8e8cb2d
·
1 Parent(s): 9e2f8cf
Files changed (6) hide show
  1. .python-version +1 -0
  2. Makefile +4 -0
  3. app.py +198 -0
  4. pyproject.toml +13 -0
  5. requirements.txt +250 -0
  6. uv.lock +0 -0
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
Makefile ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .PHONY: pre-deploy
2
+
3
+ pre-deploy:
4
+ uv export --no-hashes --frozen > requirements.txt
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import gradio as gr
10
+ import nltk
11
+
12
+
13
def _resolve_open_provence_paths() -> Path:
    """Find the open_provence checkout whose bundled utilities we import.

    Search order: the ``OPEN_PROVENCE_REPO`` environment variable, then an
    ``open_provence`` directory beside (or above) this file, and finally a
    hard-coded developer path.

    Raises:
        RuntimeError: when no candidate contains the bundled inference script.
    """

    def _is_repo(root: Path) -> bool:
        # A usable checkout always ships the bundled inference script.
        return (root / "scripts" / "open_provence_infer.py").exists()

    env_hint = os.getenv("OPEN_PROVENCE_REPO")
    if env_hint:
        hinted = Path(env_hint).expanduser().resolve()
        if _is_repo(hinted):
            return hinted

    here = Path(__file__).resolve().parent
    for ancestor in (here, *here.parents):
        nested = ancestor / "open_provence"
        if _is_repo(nested):
            return nested

    fallback = Path("/home/hotchpotch/src/github.com/hotchpotch/open_provence").resolve()
    if _is_repo(fallback):
        return fallback

    raise RuntimeError(
        "open_provence repository not found. Set OPEN_PROVENCE_REPO to the repository root."
    )
35
+
36
+
37
# Locate the repository once at import time and make its packages
# (scripts.*, open_provence.*) importable for the imports below.
OPEN_PROVENCE_ROOT = _resolve_open_provence_paths()
_root = str(OPEN_PROVENCE_ROOT)
if _root not in sys.path:
    sys.path.insert(0, _root)
40
+
41
+
42
+ from scripts.open_provence_infer import ( # type: ignore # noqa: E402
43
+ InputRecord,
44
+ _infer_record,
45
+ load_model,
46
+ )
47
+ from open_provence.modeling_open_provence_standalone import ( # type: ignore # noqa: E402
48
+ OpenProvenceModel,
49
+ resolve_inference_device,
50
+ )
51
+
52
+
53
def _ensure_nltk_punkt_resources() -> None:
    """Download punkt resources on first run so inference never fails later.

    ``nltk.download`` reports failure only via its return value / logging, so
    a download that silently fails (e.g. no network) would otherwise surface
    as a LookupError in the middle of the first inference. We re-check after
    downloading and fail fast at startup with an actionable message instead.

    Raises:
        RuntimeError: if a required resource is still missing after the
            download attempt.
    """

    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            # quiet=True suppresses the progress output; success is verified
            # by the re-check below rather than the return value.
            nltk.download(resource, quiet=True)
            try:
                nltk.data.find(f"tokenizers/{resource}")
            except LookupError as exc:
                raise RuntimeError(
                    f"NLTK resource 'tokenizers/{resource}' could not be downloaded. "
                    "Check network access or pre-install it via nltk.download()."
                ) from exc
61
+
62
+
63
# Model used when the user leaves the identifier field blank.
DEFAULT_MODEL = "hotchpotch/open-provence-reranker-japanese-xsmall-v1"
# Sentence-keep threshold handed to _infer_record.
DEFAULT_THRESHOLD = 0.1
# Batch size handed to _infer_record.
DEFAULT_BATCH_SIZE = 8
# Splitter languages selectable in the UI radio.
SPLITTER_CHOICES = ("en", "ja")


# Optional device override read from the environment; resolve_inference_device
# picks the actual device (hint semantics live in open_provence — confirm there).
DEVICE_HINT = os.getenv("OPEN_PROVENCE_INFER_DEVICE")
INFERENCE_DEVICE = resolve_inference_device(DEVICE_HINT)

# Fetch punkt/punkt_tab once at import time, before any inference runs.
_ensure_nltk_punkt_resources()
73
+
74
+
75
@functools.lru_cache(maxsize=4)
def _load_model_cached(model_identifier: str) -> OpenProvenceModel:
    """Load a model once per identifier, falling back to DEFAULT_MODEL when blank.

    NOTE: the cache key is the *raw* identifier string, so "  m " and "m"
    occupy separate cache slots even though they resolve to the same model.
    """
    name = model_identifier.strip()
    if not name:
        name = DEFAULT_MODEL
    return load_model(name, device=INFERENCE_DEVICE)
79
+
80
+
81
def _format_summary(*, model_name: str, score: Optional[float], compression: Optional[float], duration: Optional[float]) -> str:
    """Render a small Markdown summary of one inference run.

    Any metric that is not a plain number (e.g. ``None``) is shown as "N/A".
    """

    def _num(value: Optional[float], spec: str, suffix: str) -> str:
        # Mirrors the original behaviour: only real numbers get formatted.
        if isinstance(value, (int, float)):
            return f"{value:{spec}}{suffix}"
        return "N/A"

    lines = [
        f"**Model**: `{model_name}`",
        f"- Score: {_num(score, '.4f', '')}",
        f"- Compression: {_num(compression, '.1f', '%')}",
        f"- Processing time: {_num(duration, '.2f', 's')}",
    ]
    return "\n".join(lines)
91
+
92
+
93
def run_inference(model_identifier: str, splitter_language: str, question: str, title: str, text: str) -> tuple[str, str]:
    """Validate the form inputs, run pruning inference, and format the result.

    Returns a ``(summary_markdown, pruned_text)`` pair for the two output
    widgets. Raises ``gr.Error`` with a user-facing message on empty input or
    when the model cannot be loaded.
    """
    cleaned_question = (question or "").strip()
    cleaned_text = (text or "").strip()
    cleaned_title = (title or "").strip() or None

    # Guard clauses: surface validation problems as Gradio errors.
    if not cleaned_question:
        raise gr.Error("質問文を入力してください。")
    if not cleaned_text:
        raise gr.Error("本文テキストを入力してください。")

    # A blank or whitespace-only field falls back to the default model.
    model_name = (model_identifier or DEFAULT_MODEL).strip() or DEFAULT_MODEL

    try:
        model = _load_model_cached(model_name)
    except Exception as exc:  # pragma: no cover - user provided paths can fail
        raise gr.Error(f"モデルの読み込みに失敗しました: {exc}") from exc

    # Unknown splitter choices are passed through as None (library default).
    splitter_value = splitter_language if splitter_language in SPLITTER_CHOICES else None

    record = InputRecord(question=cleaned_question, text=cleaned_text, title=cleaned_title)
    result = _infer_record(
        model,
        record,
        threshold=DEFAULT_THRESHOLD,
        language=splitter_value,
        batch_size=DEFAULT_BATCH_SIZE,
        first_line_as_title=False,
        debug=False,
    )

    return (
        _format_summary(
            model_name=model_name,
            score=result.score,
            compression=result.compression_rate,
            duration=result.total_seconds,
        ),
        result.pruned_text,
    )
131
+
132
+
133
def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire the run button to run_inference."""
    splitter_default = "en"

    with gr.Blocks(title="Open Provence Gradio Demo") as demo:
        # Header / description shown above the form.
        gr.Markdown(
            """
            ## Open Provence Reranker Demo

            デフォルトモデル: `hotchpotch/open-provence-reranker-japanese-xsmall-v1`
            CPUで動かすための最軽量モデル。英語と日本語、両方に対応しています。
            """
        )

        # Model / splitter configuration.
        with gr.Group():
            model_input = gr.Textbox(
                label="Model identifier",
                value=DEFAULT_MODEL,
                placeholder="例: hotchpotch/...",
                info="CPUで動かすための最軽量モデル。英語と日本語、両方に対応しています。",
            )
            splitter_radio = gr.Radio(
                label="Sentence splitter language",
                choices=list(SPLITTER_CHOICES),
                value=splitter_default,
                info=(
                    "文分割時に使用する言語を固定します。モデル付属の分割器を使いたい場合は"
                    "事前に CLI で --splitter-lang を設定する構成を参照してください。"
                ),
            )

        # Free-text inputs: question, optional title, and the document body.
        with gr.Group():
            question_input = gr.Textbox(
                label="Question",
                placeholder="モデルに答えてほしい質問を入力してください。",
                lines=2,
            )
            title_input = gr.Textbox(
                label="Title (optional)",
                placeholder="タイトルがあれば入力してください。",
                lines=1,
            )
            text_input = gr.Textbox(
                label="Text",
                placeholder="要約・抽出したい本文を入力してください。",
                lines=12,
            )

        run_button = gr.Button("Run inference", variant="primary")

        # Outputs: Markdown summary plus the pruned document text.
        summary_output = gr.Markdown(label="Summary")
        pruned_output = gr.Textbox(label="Pruned Text", lines=12)

        # Wire the button; input order must match run_inference's signature.
        run_button.click(
            fn=run_inference,
            inputs=[model_input, splitter_radio, question_input, title_input, text_input],
            outputs=[summary_output, pruned_output],
        )

    return demo
192
+
193
+
194
# Built at import time: hosting runners (e.g. Hugging Face Spaces / the
# `gradio` CLI) presumably look for a module-level `demo` object — confirm
# before moving this inside the __main__ guard.
demo = build_interface()


if __name__ == "__main__":
    demo.launch()
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "open-provence-demo"
3
+ version = "0.1.0"
4
+ description = "Gradio demo app for the Open Provence reranker (sentence pruning for RAG)"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "fast-bunkai>=0.1.1",
9
+ "gradio==5.49.1",
10
+ "nltk>=3.9.2",
11
+ "torch>=2.9.0",
12
+ "transformers>=4.57.1",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --no-hashes --frozen
3
+ aiofiles==24.1.0
4
+ # via gradio
5
+ annotated-doc==0.0.3
6
+ # via fastapi
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.11.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ audioop-lts==0.2.2 ; python_full_version >= '3.13'
15
+ # via gradio
16
+ brotli==1.1.0
17
+ # via gradio
18
+ certifi==2025.10.5
19
+ # via
20
+ # httpcore
21
+ # httpx
22
+ # requests
23
+ charset-normalizer==3.4.4
24
+ # via requests
25
+ click==8.3.0
26
+ # via
27
+ # nltk
28
+ # typer
29
+ # uvicorn
30
+ colorama==0.4.6 ; sys_platform == 'win32'
31
+ # via
32
+ # click
33
+ # tqdm
34
+ fast-bunkai==0.1.1
35
+ # via open-provence-demo
36
+ fastapi==0.120.1
37
+ # via gradio
38
+ ffmpy==0.6.4
39
+ # via gradio
40
+ filelock==3.20.0
41
+ # via
42
+ # huggingface-hub
43
+ # torch
44
+ # transformers
45
+ fsspec==2025.9.0
46
+ # via
47
+ # gradio-client
48
+ # huggingface-hub
49
+ # torch
50
+ gradio==5.49.1
51
+ # via open-provence-demo
52
+ gradio-client==1.13.3
53
+ # via gradio
54
+ groovy==0.1.2
55
+ # via gradio
56
+ h11==0.16.0
57
+ # via
58
+ # httpcore
59
+ # uvicorn
60
+ hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
61
+ # via huggingface-hub
62
+ httpcore==1.0.9
63
+ # via httpx
64
+ httpx==0.28.1
65
+ # via
66
+ # gradio
67
+ # gradio-client
68
+ # safehttpx
69
+ huggingface-hub==0.36.0
70
+ # via
71
+ # gradio
72
+ # gradio-client
73
+ # tokenizers
74
+ # transformers
75
+ idna==3.11
76
+ # via
77
+ # anyio
78
+ # httpx
79
+ # requests
80
+ janome==0.5.0
81
+ # via fast-bunkai
82
+ jinja2==3.1.6
83
+ # via
84
+ # gradio
85
+ # torch
86
+ joblib==1.5.2
87
+ # via nltk
88
+ markdown-it-py==4.0.0
89
+ # via rich
90
+ markupsafe==3.0.3
91
+ # via
92
+ # gradio
93
+ # jinja2
94
+ mdurl==0.1.2
95
+ # via markdown-it-py
96
+ mpmath==1.3.0
97
+ # via sympy
98
+ networkx==3.5
99
+ # via torch
100
+ nltk==3.9.2
101
+ # via open-provence-demo
102
+ numpy==2.3.4
103
+ # via
104
+ # gradio
105
+ # pandas
106
+ # transformers
107
+ nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
108
+ # via
109
+ # nvidia-cudnn-cu12
110
+ # nvidia-cusolver-cu12
111
+ # torch
112
+ nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
113
+ # via torch
114
+ nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
115
+ # via torch
116
+ nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
117
+ # via torch
118
+ nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
119
+ # via torch
120
+ nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
121
+ # via torch
122
+ nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
123
+ # via torch
124
+ nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
125
+ # via torch
126
+ nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
127
+ # via torch
128
+ nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
129
+ # via
130
+ # nvidia-cusolver-cu12
131
+ # torch
132
+ nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
133
+ # via torch
134
+ nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
135
+ # via torch
136
+ nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
137
+ # via
138
+ # nvidia-cufft-cu12
139
+ # nvidia-cusolver-cu12
140
+ # nvidia-cusparse-cu12
141
+ # torch
142
+ nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
143
+ # via torch
144
+ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
145
+ # via torch
146
+ orjson==3.11.4
147
+ # via gradio
148
+ packaging==25.0
149
+ # via
150
+ # gradio
151
+ # gradio-client
152
+ # huggingface-hub
153
+ # transformers
154
+ pandas==2.3.3
155
+ # via gradio
156
+ pillow==11.3.0
157
+ # via gradio
158
+ pydantic==2.11.10
159
+ # via
160
+ # fastapi
161
+ # gradio
162
+ pydantic-core==2.33.2
163
+ # via pydantic
164
+ pydub==0.25.1
165
+ # via gradio
166
+ pygments==2.19.2
167
+ # via rich
168
+ python-dateutil==2.9.0.post0
169
+ # via pandas
170
+ python-multipart==0.0.20
171
+ # via gradio
172
+ pytz==2025.2
173
+ # via pandas
174
+ pyyaml==6.0.3
175
+ # via
176
+ # gradio
177
+ # huggingface-hub
178
+ # transformers
179
+ regex==2025.10.23
180
+ # via
181
+ # nltk
182
+ # transformers
183
+ requests==2.32.5
184
+ # via
185
+ # huggingface-hub
186
+ # transformers
187
+ rich==14.2.0
188
+ # via typer
189
+ ruff==0.14.2
190
+ # via gradio
191
+ safehttpx==0.1.7
192
+ # via gradio
193
+ safetensors==0.6.2
194
+ # via transformers
195
+ semantic-version==2.10.0
196
+ # via gradio
197
+ setuptools==80.9.0
198
+ # via torch
199
+ shellingham==1.5.4
200
+ # via typer
201
+ six==1.17.0
202
+ # via python-dateutil
203
+ sniffio==1.3.1
204
+ # via anyio
205
+ starlette==0.48.0
206
+ # via
207
+ # fastapi
208
+ # gradio
209
+ sympy==1.14.0
210
+ # via torch
211
+ tokenizers==0.22.1
212
+ # via transformers
213
+ tomlkit==0.13.3
214
+ # via gradio
215
+ torch==2.9.0
216
+ # via open-provence-demo
217
+ tqdm==4.67.1
218
+ # via
219
+ # huggingface-hub
220
+ # nltk
221
+ # transformers
222
+ transformers==4.57.1
223
+ # via open-provence-demo
224
+ triton==3.5.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
225
+ # via torch
226
+ typer==0.20.0
227
+ # via gradio
228
+ typing-extensions==4.15.0
229
+ # via
230
+ # anyio
231
+ # fastapi
232
+ # gradio
233
+ # gradio-client
234
+ # huggingface-hub
235
+ # pydantic
236
+ # pydantic-core
237
+ # starlette
238
+ # torch
239
+ # typer
240
+ # typing-inspection
241
+ typing-inspection==0.4.2
242
+ # via pydantic
243
+ tzdata==2025.2
244
+ # via pandas
245
+ urllib3==2.5.0
246
+ # via requests
247
+ uvicorn==0.38.0
248
+ # via gradio
249
+ websockets==15.0.1
250
+ # via gradio-client
uv.lock ADDED
The diff for this file is too large to render. See raw diff