Toya0421 committed on
Commit
40925de
·
verified ·
1 Parent(s): cb9f047

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +280 -0
app.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import glob
4
+ import csv
5
+ import threading
6
+ from datetime import datetime, timedelta
7
+
8
+ import gradio as gr
9
+ import textstat
10
+ from openai import OpenAI
11
+
12
# =========================
# Settings (carried over from the original code)
# =========================
# Credentials and endpoint for the OpenAI-compatible API, all read from the environment.
API_KEY = os.getenv("API_KEY")
BASE_URL = os.getenv("BASE_URL", "https://openrouter.ai/api/v1")
MODEL = os.getenv("MODEL", "google/gemini-2.5-flash")

# Hugging Face Spaces persistent storage (recommended).
# Assumes Persistent Storage is enabled for the Space so that /data is usable.
OUT_DIR = os.getenv("OUT_DIR", "/data")
os.makedirs(OUT_DIR, exist_ok=True)
CSV_PATH = os.path.join(OUT_DIR, "rewrite_scores.csv")

# Directory that holds the input passages (pg<number>.txt).
PASSAGES_DIR = os.getenv("PASSAGES_DIR", "passages")

# Fail fast at import time: nothing below can work without an API key.
if not API_KEY:
    raise RuntimeError("API_KEY is not set (env: API_KEY)")

client = OpenAI(base_url=BASE_URL, api_key=API_KEY)

# Lightly cap concurrent rewrite calls (keeps the Space stable).
# Clamp to at least 1: Semaphore(0) would block every rewrite forever and a
# negative value makes threading.Semaphore raise ValueError.
REWRITE_CONCURRENCY = max(1, int(os.getenv("REWRITE_CONCURRENCY", "2")))
_rewrite_sem = threading.Semaphore(REWRITE_CONCURRENCY)

# Module-wide stop request flag and the lock that guards it.
_stop_flag_lock = threading.Lock()
_stop_flag = False
38
+
39
+
40
+ # =========================
41
+ # passages の列挙
42
+ # =========================
43
def list_passage_files_sorted(passages_dir: str) -> list[tuple[int, str]]:
    """Return ``(number, path)`` pairs for the ``pg<number>.txt`` files in *passages_dir*.

    Candidates are found with the glob ``pg*.txt``; only files whose basename
    is ``pg`` + digits + ``.txt`` are kept. The result is sorted by the numeric
    id in ascending order. A directory with no matching files yields ``[]``.
    """
    name_pattern = re.compile(r"pg(\d+)\.txt$")
    candidates = glob.glob(os.path.join(passages_dir, "pg*.txt"))
    # Pair each candidate with its regex match so the filter and the int()
    # conversion can share one pass.
    matched = ((name_pattern.match(os.path.basename(path)), path) for path in candidates)
    numbered = [(int(hit.group(1)), path) for hit, path in matched if hit]
    return sorted(numbered, key=lambda pair: pair[0])
55
+
56
+
57
def load_text(path: str) -> str:
    """Read the file at *path* as UTF-8 text and return its entire contents."""
    with open(path, mode="r", encoding="utf-8") as source:
        contents = source.read()
    return contents
60
+
61
+
62
+ # =========================
63
+ # 書き換え(プロンプト同一)
64
+ # =========================
65
def rewrite_level(text: str, target_level: int) -> str:
    """Ask the chat model to rewrite *text* toward a target readability level.

    Args:
        text: The original passage; it is embedded verbatim at the end of the prompt.
        target_level: Level 1..5, mapped to a target Flesch Reading Ease score
            (1 -> 90, 2 -> 70, 3 -> 55, 4 -> 40, 5 -> 25).

    Returns:
        The model's reply with leading/trailing whitespace stripped.

    Raises:
        KeyError: If ``target_level`` is not one of 1..5.
        RuntimeError: If the response carries no choices or no text content.
    """
    level_to_flesch = {1: 90, 2: 70, 3: 55, 4: 40, 5: 25}
    target_flesch = level_to_flesch[int(target_level)]

    prompt = f"""
Rewrite the following passage so it fits about {target_flesch} Flesch Reading Ease Score
- Extract only the portions of the text that should be read as the main body,
excluding the title, author name, source information, chapter number, annotations, and footers.
- When outputting, make sure sections divided by chapters, etc., are clearly distinguishable by leaving a blank line between them.
- Preserve the original meaning faithfully.
- Do not add new information or remove essential information.
- Output only the rewritten passage. Do not include explanations.
{text}
"""

    # Hold the semaphore only for the network call so concurrent API requests stay capped.
    with _rewrite_sem:
        resp = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.4,
            max_tokens=5000,
        )

    # The SDK types message.content as optional; surface that case explicitly
    # instead of failing with AttributeError on None.strip().
    content = resp.choices[0].message.content if resp.choices else None
    if content is None:
        raise RuntimeError("Model response contained no text content")
    return content.strip()
88
+
89
+
90
+ # =========================
91
+ # 指標(FRE + 単語数)
92
+ # =========================
93
# One English word: a run of ASCII letters, optionally followed by an
# apostrophe and more letters so contractions/possessives count once.
# Both the straight (') and the typographic (U+2019) apostrophe are accepted;
# with only the straight one, "don’t" was counted as two words.
_word_re = re.compile(r"[A-Za-z]+(?:['\u2019][A-Za-z]+)?")


def count_words_english(text: str) -> int:
    """Return the number of English words in *text*.

    A word is a run of ASCII letters, optionally with one internal apostrophe
    part (``don't``, ``it’s``). Digits, punctuation and hyphens separate words.
    """
    return len(_word_re.findall(text))
97
+
98
def compute_metrics(text: str) -> tuple[float, int]:
    """Return ``(flesch_reading_ease, word_count)`` for *text*.

    The reading-ease score comes from ``textstat.flesch_reading_ease`` (coerced
    to ``float``); the word count comes from ``count_words_english``.
    """
    return float(textstat.flesch_reading_ease(text)), count_words_english(text)
102
+
103
+
104
+ # =========================
105
+ # CSV追記(軽量・永続)
106
+ # =========================
107
# Serializes all access to the CSV file within this process.
_csv_lock = threading.Lock()


def append_csv_row(row: dict):
    """Append one row to the CSV at ``CSV_PATH``, writing the header when needed.

    Args:
        row: Mapping of column name to value. Columns missing from *row* are
            written as empty strings; keys that are not known columns are ignored.

    The header is written when the file does not exist yet or exists but is
    empty (zero bytes), so the first data row is never left without a header.
    """
    fieldnames = ["timestamp_jst", "Text#", "target_level", "flesch_reading_ease", "word_count", "rewritten_text"]
    with _csv_lock:
        # A missing file and a zero-byte file both need the header.
        try:
            needs_header = os.path.getsize(CSV_PATH) == 0
        except OSError:
            needs_header = True

        # newline="" lets the csv module control line endings itself.
        with open(CSV_PATH, "a", encoding="utf-8", newline="") as f:
            w = csv.DictWriter(f, fieldnames=fieldnames)
            if needs_header:
                w.writeheader()
            w.writerow({k: row.get(k, "") for k in fieldnames})
121
+
122
+
123
+ # =========================
124
+ # 停止フラグ
125
+ # =========================
126
def set_stop(flag: bool):
    """Store *flag* as the module-wide stop request while holding ``_stop_flag_lock``."""
    global _stop_flag
    _stop_flag_lock.acquire()
    try:
        _stop_flag = flag
    finally:
        _stop_flag_lock.release()
130
+
131
def get_stop() -> bool:
    """Return the current stop request flag, read while holding ``_stop_flag_lock``."""
    _stop_flag_lock.acquire()
    try:
        current = _stop_flag
    finally:
        _stop_flag_lock.release()
    return current
134
+
135
+
136
+ # =========================
137
+ # UIロジック
138
+ # =========================
139
def init_state():
    """Build a fresh UI state from the passages currently in ``PASSAGES_DIR``.

    Returns a dict with:
      - ``"files"``: ``[(text_id, path), ...]`` as returned by ``list_passage_files_sorted``
      - ``"idx"``: position of the next passage to process, starting at 0
    """
    passage_files = list_passage_files_sorted(PASSAGES_DIR)
    return dict(files=passage_files, idx=0)
145
+
146
def start(level: int):
    """Clear any stop request, rescan the passages and report readiness.

    *level* is accepted because the UI passes it, but it is not used here.

    Returns a 6-tuple ``(state, status, header, progress, rewritten_text,
    stop_note_update)``; the three middle text fields are always empty and the
    final update shows the note only when at least one passage was found.
    """
    set_stop(False)
    st = init_state()
    passage_files = st["files"]

    if not passage_files:
        return st, "passages/pg*.txt が見つかりません", "", "", "", gr.update(visible=False)

    first_id = passage_files[0][0]
    msg = f"準備完了: {len(passage_files)}件。次に処理するのは #Text {first_id} です。"
    return st, msg, "", "", "", gr.update(visible=True)
155
+
156
def _jst_timestamp() -> str:
    """Return the current time in JST (UTC+9) formatted as ``YYYY-MM-DD HH:MM:SS``."""
    # Function-scope import: the file-level import only brings in datetime and timedelta.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12; take an aware "now" in UTC+9 instead.
    return datetime.now(timezone(timedelta(hours=9))).strftime("%Y-%m-%d %H:%M:%S")


def _process_passage(level: int, text_id: int, path: str) -> tuple[str, str]:
    """Rewrite one passage, score it and append its CSV row; return ``(header, rewritten)``."""
    rewritten = rewrite_level(load_text(path), target_level=level)
    fre, wc = compute_metrics(rewritten)

    append_csv_row({
        "timestamp_jst": _jst_timestamp(),
        "Text#": text_id,
        "target_level": level,
        "flesch_reading_ease": f"{fre:.2f}",
        "word_count": wc,
        "rewritten_text": rewritten,
    })

    header = f"#Text {text_id}\nTarget Level: {level}\nFlesch Reading Ease: {fre:.2f}\nWord Count: {wc}\nSaved: {CSV_PATH}"
    return header, rewritten


def run_one(level: int, state: dict):
    """Process the next unprocessed passage and advance the cursor by one.

    Args:
        level: Target level forwarded to ``rewrite_level``.
        state: UI state dict with ``"files"`` (list of ``(text_id, path)``) and
            ``"idx"`` (next position). It is updated in place and also returned.

    Returns:
        ``(state, status, header, progress, rewritten_text, stop_note_update)``.
        When everything is already processed, header and text are empty and
        progress is ``"idx / total"`` (the same format ``run_all`` uses).
    """
    set_stop(False)

    files = state.get("files", [])
    idx = int(state.get("idx", 0))
    total = len(files)

    if idx >= total:
        return state, "全て処理済みです。", "", f"{idx} / {total}", "", gr.update(visible=True)

    text_id, path = files[idx]
    header, rewritten = _process_passage(level, text_id, path)

    state["idx"] = idx + 1
    progress = f"{state['idx']} / {total}"
    return state, "1件処理しました。", header, progress, rewritten, gr.update(visible=True)


def run_all(level: int, state: dict):
    """Process all remaining passages in order; the Stop button can interrupt between items.

    Args:
        level: Target level forwarded to ``rewrite_level``.
        state: UI state dict with ``"files"`` and ``"idx"``; updated in place
            after every processed passage and also returned.

    Returns:
        ``(state, status, header, progress, rewritten_text, stop_note_update)``,
        where header and text describe the last passage processed in this call
        (empty if none was processed).
    """
    set_stop(False)

    files = state.get("files", [])
    idx = int(state.get("idx", 0))
    total = len(files)

    if idx >= total:
        return state, "全て処理済みです。", "", f"{idx} / {total}", "", gr.update(visible=True)

    last_header = ""
    last_text = ""

    while idx < total:
        # The stop flag is checked only between passages, so an in-flight
        # rewrite always completes and is saved before the loop exits.
        if get_stop():
            state["idx"] = idx
            return state, "停止しました。", last_header, f"{idx} / {total}", last_text, gr.update(visible=True)

        text_id, path = files[idx]
        last_header, last_text = _process_passage(level, text_id, path)

        idx += 1
        state["idx"] = idx

    return state, "全件処理が完了しました。", last_header, f"{idx} / {total}", last_text, gr.update(visible=True)
235
+
236
def stop():
    """Raise the stop request flag and return the status message shown to the user."""
    set_stop(True)
    notice = "停止要求を受け付けました(処理中の1件が終わったタイミングで止まります)。"
    return notice
239
+
240
def reset_csv():
    """Delete the scores CSV at ``CSV_PATH`` if it exists and return a status message.

    The removal happens under ``_csv_lock`` so it cannot interleave with an
    append from this process. A missing file is not an error.
    """
    with _csv_lock:
        # EAFP: attempt the removal directly instead of exists()-then-remove(),
        # which can fail if the file disappears between the two calls.
        try:
            os.remove(CSV_PATH)
        except FileNotFoundError:
            pass
    return f"CSVを削除しました: {CSV_PATH}"
245
+
246
+
247
+ # =========================
248
+ # Gradio UI(Spaces向け)
249
+ # =========================
250
# =========================
# Gradio UI (for Spaces)
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🔁 Passage Rewrite + FRE Scoring (HF Spaces)")

    # Per-session state: passage list plus the cursor of the next item.
    state = gr.State(init_state())

    level = gr.Dropdown(label="Target Level (1..5)", choices=[1, 2, 3, 4, 5], value=1)
    status = gr.Textbox(label="Status", interactive=False)
    header = gr.Textbox(label="Result Header (#Text / FRE / Words)", lines=5, interactive=False)
    progress = gr.Textbox(label="Progress", interactive=False)
    output_text = gr.Textbox(label="Rewritten Text", lines=18, interactive=False)

    with gr.Row():
        start_btn = gr.Button("開始(ファイル読み込み)")
        one_btn = gr.Button("次へ(1件処理)")
        all_btn = gr.Button("全件実行(残りを処理)")
        stop_btn = gr.Button("停止")
    with gr.Row():
        reset_btn = gr.Button("CSVリセット(削除)")
        csv_hint = gr.Markdown(f"📄 CSV保存先: `{CSV_PATH}`(SpacesのFilesに出ます)")

    # Placeholder used to surface the stop notice.
    stop_note = gr.Markdown(visible=False)

    # start / run_one / run_all all return the same six-value tuple, so they share one output list.
    result_outputs = [state, status, header, progress, output_text, stop_note]

    start_btn.click(fn=start, inputs=[level], outputs=result_outputs)
    one_btn.click(fn=run_one, inputs=[level, state], outputs=result_outputs)
    all_btn.click(fn=run_all, inputs=[level, state], outputs=result_outputs)

    # stop / reset_csv only return a status message.
    stop_btn.click(fn=stop, inputs=[], outputs=[status])
    reset_btn.click(fn=reset_csv, inputs=[], outputs=[status])

demo.queue(max_size=64)
demo.launch()