Toya0421 committed on
Commit
c1745f2
·
verified ·
1 Parent(s): ad3c4eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -45
app.py CHANGED
@@ -3,6 +3,7 @@ import re
3
  import glob
4
  import csv
5
  import threading
 
6
  from datetime import datetime, timedelta
7
  from typing import Optional
8
 
@@ -14,20 +15,19 @@ from openai import OpenAI
14
  # 設定(元コード踏襲)
15
  # =========================
16
  API_KEY = os.getenv("API_KEY")
17
- BASE_URL = os.getenv("BASE_URL", "https://openrouter.ai/api/v1")
18
  MODEL = os.getenv("MODEL", "google/gemini-2.5-flash")
19
 
20
- # Hugging Face Spaces 永続ストレージ(Persistent Storage 有効なら /data 永続
21
- OUT_DIR = os.getenv("OUT_DIR", "/data")
 
22
  os.makedirs(OUT_DIR, exist_ok=True)
23
 
24
- # スコアCSV(任意だが便利)
25
  CSV_PATH = os.path.join(OUT_DIR, "rewrite_scores.csv")
26
 
27
  _word_re = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?")
28
 
29
- # ★追加:書き換え本文をtxtで蓄積するフォルダ(要件:rewrite_passages)
30
- # 既定は /data/rewrite_passages(永続に残る)
31
  REWRITE_DIR = "rewrite_passages"
32
  os.makedirs(REWRITE_DIR, exist_ok=True)
33
 
@@ -38,7 +38,6 @@ if not API_KEY:
38
 
39
  client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
40
 
41
- # 同時実行を軽く制限(Spacesで安定させる)
42
  REWRITE_CONCURRENCY = int(os.getenv("REWRITE_CONCURRENCY", "2"))
43
  _rewrite_sem = threading.Semaphore(REWRITE_CONCURRENCY)
44
 
@@ -83,10 +82,9 @@ excluding the title, author name, source information, chapter number, annotation
83
  {text}
84
  """
85
 
86
- # 安全な max_tokens 候補(大→小)
87
  max_tokens_candidates = [3000, 2000, 1500, 1000]
88
-
89
  last_error = None
 
90
  for mt in max_tokens_candidates:
91
  try:
92
  with _rewrite_sem:
@@ -94,22 +92,21 @@ excluding the title, author name, source information, chapter number, annotation
94
  model=MODEL,
95
  messages=[{"role": "user", "content": prompt}],
96
  temperature=0.4,
97
- max_tokens=mt
98
  )
99
  return resp.choices[0].message.content.strip()
100
 
101
  except Exception as e:
102
  last_error = e
103
- # 402(クレジット不足 or トークン過多)の場合は縮めて再試行
104
  if "402" in str(e) or "more credits" in str(e):
105
  continue
106
- else:
107
- raise e
108
 
109
- # 全部ダメなら最後のエラーを投げる
110
  raise RuntimeError(f"Rewrite failed after retries: {last_error}")
111
 
112
-
 
 
113
  def count_words_english(text: str) -> int:
114
  return len(_word_re.findall(text))
115
 
@@ -119,12 +116,11 @@ def compute_metrics(text: str) -> tuple[float, int]:
119
  return fre, wc
120
 
121
  # =========================
122
- # CSV追記(軽量・永続)
123
  # =========================
124
  _csv_lock = threading.Lock()
125
 
126
  def append_csv_row(row: dict):
127
- """rewrite_scores.csv に1行追記(ヘッダ無ければ作成)"""
128
  fieldnames = ["timestamp_jst", "Text#", "target_level", "flesch_reading_ease", "word_count", "rewritten_text"]
129
  with _csv_lock:
130
  exists = os.path.exists(CSV_PATH)
@@ -135,7 +131,7 @@ def append_csv_row(row: dict):
135
  w.writerow({k: row.get(k, "") for k in fieldnames})
136
 
137
  # =========================
138
- # ★追加:rewrite_passages に txt 追記
139
  # =========================
140
  _txt_lock = threading.Lock()
141
 
@@ -146,9 +142,6 @@ def append_rewrite_txt(
146
  word_count: int,
147
  rewritten_text: str,
148
  ):
149
- """
150
- rewrite_passages/Text_{id}.txt に結果を追記(同じText#の再実行も追記で残す)
151
- """
152
  ts = (datetime.utcnow() + timedelta(hours=9)).strftime("%Y-%m-%d %H:%M:%S")
153
  path = os.path.join(REWRITE_DIR, f"Text_{text_id}.txt")
154
 
@@ -158,11 +151,9 @@ def append_rewrite_txt(
158
  f"Flesch Reading Ease: {fre:.2f}\n"
159
  f"Word Count: {word_count}\n"
160
  f"Timestamp (JST): {ts}\n"
161
- f"Model: {MODEL}\n"
162
- f"\n"
163
  f"---- Rewritten Text ----\n"
164
- f"{rewritten_text}\n"
165
- f"\n"
166
  f"{'=' * 80}\n"
167
  )
168
 
@@ -194,10 +185,10 @@ def start(level: int):
194
  st = init_state()
195
  total = len(st["files"])
196
  if total == 0:
197
- return st, "passages/pg*.txt が見つかりません", "", "", "", gr.update(visible=False)
198
 
199
  msg = f"準備完了: {total}件。次に処理するのは #Text {st['files'][0][0]} です。"
200
- return st, msg, "", "", "", gr.update(visible=True)
201
 
202
  def run_one(level: int, state: dict):
203
  set_stop(False)
@@ -207,7 +198,7 @@ def run_one(level: int, state: dict):
207
  total = len(files)
208
 
209
  if idx >= total:
210
- return state, "全て処理済みです。", "", "", "", gr.update(visible=True)
211
 
212
  text_id, path = files[idx]
213
  original = load_text(path)
@@ -215,10 +206,8 @@ def run_one(level: int, state: dict):
215
  rewritten = rewrite_level(original, target_level=level)
216
  fre, wc = compute_metrics(rewritten)
217
 
218
- # JSTタイムスタンプ(+9)
219
  ts = (datetime.utcnow() + timedelta(hours=9)).strftime("%Y-%m-%d %H:%M:%S")
220
 
221
- # CSV追記
222
  append_csv_row({
223
  "timestamp_jst": ts,
224
  "Text#": text_id,
@@ -228,7 +217,6 @@ def run_one(level: int, state: dict):
228
  "rewritten_text": rewritten
229
  })
230
 
231
- # ★txt追記
232
  append_rewrite_txt(
233
  text_id=text_id,
234
  target_level=level,
@@ -248,10 +236,9 @@ def run_one(level: int, state: dict):
248
  f"Saved TXT: {os.path.join(REWRITE_DIR, f'Text_{text_id}.txt')}"
249
  )
250
  progress = f"{state['idx']} / {total}"
251
- return state, "1件処理しました。", header, progress, rewritten, gr.update(visible=True)
252
 
253
  def run_all(level: int, state: dict):
254
- """全件(または残り)を順次処理。途中で「停止」ボタンで止められる。"""
255
  set_stop(False)
256
 
257
  files = state.get("files", [])
@@ -259,7 +246,7 @@ def run_all(level: int, state: dict):
259
  total = len(files)
260
 
261
  if idx >= total:
262
- return state, "全て処理済みです。", "", f"{idx} / {total}", "", gr.update(visible=True)
263
 
264
  last_header = ""
265
  last_text = ""
@@ -267,7 +254,7 @@ def run_all(level: int, state: dict):
267
  while idx < total:
268
  if get_stop():
269
  state["idx"] = idx
270
- return state, "停止しました。", last_header, f"{idx} / {total}", last_text, gr.update(visible=True)
271
 
272
  text_id, path = files[idx]
273
  original = load_text(path)
@@ -286,7 +273,6 @@ def run_all(level: int, state: dict):
286
  "rewritten_text": rewritten
287
  })
288
 
289
- # ★txt追記
290
  append_rewrite_txt(
291
  text_id=text_id,
292
  target_level=level,
@@ -308,7 +294,7 @@ def run_all(level: int, state: dict):
308
  idx += 1
309
  state["idx"] = idx
310
 
311
- return state, "全件処理が完了しました。", last_header, f"{idx} / {total}", last_text, gr.update(visible=True)
312
 
313
  def stop():
314
  set_stop(True)
@@ -321,9 +307,6 @@ def reset_csv():
321
  return f"CSVを削除しました: {CSV_PATH}"
322
 
323
  def reset_rewrite_folder():
324
- """
325
- rewrite_passages を全消しは危険なので、ここでは「中の txt を削除」する実装。
326
- """
327
  removed = 0
328
  with _txt_lock:
329
  for fp in glob.glob(os.path.join(REWRITE_DIR, "Text_*.txt")):
@@ -334,6 +317,65 @@ def reset_rewrite_folder():
334
  pass
335
  return f"rewrite_passages の Text_*.txt を削除しました({removed}件): {REWRITE_DIR}"
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  # =========================
338
  # Gradio UI(Spaces向け)
339
  # =========================
@@ -358,15 +400,46 @@ with gr.Blocks() as demo:
358
  reset_btn = gr.Button("CSVリセット(削除)")
359
  reset_txt_btn = gr.Button("rewrite_passagesリセット(Text_*.txt削除)")
360
 
361
- gr.Markdown(f"📄 CSV保存先: `{CSV_PATH}`(SpacesのFilesに出ます)")
362
- gr.Markdown(f"📝 TXT保存先: `{REWRITE_DIR}`(SpacesのFilesに出ます)")
363
 
364
- start_btn.click(fn=start, inputs=[level], outputs=[state, status, header, progress, output_text])
365
- one_btn.click(fn=run_one, inputs=[level, state], outputs=[state, status, header, progress, output_text])
366
- all_btn.click(fn=run_all, inputs=[level, state], outputs=[state, status, header, progress, output_text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  stop_btn.click(fn=stop, inputs=[], outputs=[status])
368
  reset_btn.click(fn=reset_csv, inputs=[], outputs=[status])
369
  reset_txt_btn.click(fn=reset_rewrite_folder, inputs=[], outputs=[status])
370
 
 
 
 
 
 
 
371
  demo.queue(max_size=64)
372
  demo.launch()
 
3
  import glob
4
  import csv
5
  import threading
6
+ import shutil
7
  from datetime import datetime, timedelta
8
  from typing import Optional
9
 
 
15
  # 設定(元コード踏襲)
16
  # =========================
17
  API_KEY = os.getenv("API_KEY")
18
+ BASE_URL = "https://openrouter.ai/api/v1"
19
  MODEL = os.getenv("MODEL", "google/gemini-2.5-flash")
20
 
21
+ # Free Space では /data 永続ではない可能性が高いので、
22
+ # ダウンロード前提で「カレント」に保存してOK
23
+ OUT_DIR = os.getenv("OUT_DIR", ".")
24
  os.makedirs(OUT_DIR, exist_ok=True)
25
 
 
26
  CSV_PATH = os.path.join(OUT_DIR, "rewrite_scores.csv")
27
 
28
  _word_re = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?")
29
 
30
+ # 書き換え本文をtxtで蓄積するフォルダ
 
31
  REWRITE_DIR = "rewrite_passages"
32
  os.makedirs(REWRITE_DIR, exist_ok=True)
33
 
 
38
 
39
  client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
40
 
 
41
  REWRITE_CONCURRENCY = int(os.getenv("REWRITE_CONCURRENCY", "2"))
42
  _rewrite_sem = threading.Semaphore(REWRITE_CONCURRENCY)
43
 
 
82
  {text}
83
  """
84
 
 
85
  max_tokens_candidates = [3000, 2000, 1500, 1000]
 
86
  last_error = None
87
+
88
  for mt in max_tokens_candidates:
89
  try:
90
  with _rewrite_sem:
 
92
  model=MODEL,
93
  messages=[{"role": "user", "content": prompt}],
94
  temperature=0.4,
95
+ max_tokens=mt,
96
  )
97
  return resp.choices[0].message.content.strip()
98
 
99
  except Exception as e:
100
  last_error = e
 
101
  if "402" in str(e) or "more credits" in str(e):
102
  continue
103
+ raise
 
104
 
 
105
  raise RuntimeError(f"Rewrite failed after retries: {last_error}")
106
 
107
+ # =========================
108
+ # 指標(FRE + 単語数)
109
+ # =========================
110
def count_words_english(text: str) -> int:
    """Return how many matches of the module-level ``_word_re`` pattern occur in *text*."""
    return sum(1 for _ in _word_re.finditer(text))
112
 
 
116
  return fre, wc
117
 
118
  # =========================
119
+ # CSV追記
120
  # =========================
121
  _csv_lock = threading.Lock()
122
 
123
  def append_csv_row(row: dict):
 
124
  fieldnames = ["timestamp_jst", "Text#", "target_level", "flesch_reading_ease", "word_count", "rewritten_text"]
125
  with _csv_lock:
126
  exists = os.path.exists(CSV_PATH)
 
131
  w.writerow({k: row.get(k, "") for k in fieldnames})
132
 
133
  # =========================
134
+ # txt追記
135
  # =========================
136
  _txt_lock = threading.Lock()
137
 
 
142
  word_count: int,
143
  rewritten_text: str,
144
  ):
 
 
 
145
  ts = (datetime.utcnow() + timedelta(hours=9)).strftime("%Y-%m-%d %H:%M:%S")
146
  path = os.path.join(REWRITE_DIR, f"Text_{text_id}.txt")
147
 
 
151
  f"Flesch Reading Ease: {fre:.2f}\n"
152
  f"Word Count: {word_count}\n"
153
  f"Timestamp (JST): {ts}\n"
154
+ f"Model: {MODEL}\n\n"
 
155
  f"---- Rewritten Text ----\n"
156
+ f"{rewritten_text}\n\n"
 
157
  f"{'=' * 80}\n"
158
  )
159
 
 
185
  st = init_state()
186
  total = len(st["files"])
187
  if total == 0:
188
+ return st, "passages/pg*.txt が見つかりません", "", "", "", None, None
189
 
190
  msg = f"準備完了: {total}件。次に処理するのは #Text {st['files'][0][0]} です。"
191
+ return st, msg, "", "", "", None, None
192
 
193
  def run_one(level: int, state: dict):
194
  set_stop(False)
 
198
  total = len(files)
199
 
200
  if idx >= total:
201
+ return state, "全て処理済みです。", "", "", "", None, None
202
 
203
  text_id, path = files[idx]
204
  original = load_text(path)
 
206
  rewritten = rewrite_level(original, target_level=level)
207
  fre, wc = compute_metrics(rewritten)
208
 
 
209
  ts = (datetime.utcnow() + timedelta(hours=9)).strftime("%Y-%m-%d %H:%M:%S")
210
 
 
211
  append_csv_row({
212
  "timestamp_jst": ts,
213
  "Text#": text_id,
 
217
  "rewritten_text": rewritten
218
  })
219
 
 
220
  append_rewrite_txt(
221
  text_id=text_id,
222
  target_level=level,
 
236
  f"Saved TXT: {os.path.join(REWRITE_DIR, f'Text_{text_id}.txt')}"
237
  )
238
  progress = f"{state['idx']} / {total}"
239
+ return state, "1件処理しました。", header, progress, rewritten, None, None
240
 
241
  def run_all(level: int, state: dict):
 
242
  set_stop(False)
243
 
244
  files = state.get("files", [])
 
246
  total = len(files)
247
 
248
  if idx >= total:
249
+ return state, "全て処理済みです。", "", f"{idx} / {total}", "", None, None
250
 
251
  last_header = ""
252
  last_text = ""
 
254
  while idx < total:
255
  if get_stop():
256
  state["idx"] = idx
257
+ return state, "停止しました。", last_header, f"{idx} / {total}", last_text, None, None
258
 
259
  text_id, path = files[idx]
260
  original = load_text(path)
 
273
  "rewritten_text": rewritten
274
  })
275
 
 
276
  append_rewrite_txt(
277
  text_id=text_id,
278
  target_level=level,
 
294
  idx += 1
295
  state["idx"] = idx
296
 
297
+ return state, "全件処理が完了しました。", last_header, f"{idx} / {total}", last_text, None, None
298
 
299
  def stop():
300
  set_stop(True)
 
307
  return f"CSVを削除しました: {CSV_PATH}"
308
 
309
  def reset_rewrite_folder():
 
 
 
310
  removed = 0
311
  with _txt_lock:
312
  for fp in glob.glob(os.path.join(REWRITE_DIR, "Text_*.txt")):
 
317
  pass
318
  return f"rewrite_passages の Text_*.txt を削除しました({removed}件): {REWRITE_DIR}"
319
 
320
+ # =========================
321
+ # ★追加:ダウンロード機能
322
+ # =========================
323
def list_generated_txt_files() -> list[str]:
    """Return the base names of every ``Text_*.txt`` file in ``REWRITE_DIR``.

    The names are ordered by their lexicographically sorted full paths.
    """
    pattern = os.path.join(REWRITE_DIR, "Text_*.txt")
    matches = glob.glob(pattern)
    matches.sort()
    return list(map(os.path.basename, matches))
326
+
327
def build_single_txt_path(selected_name: str) -> str:
    """Resolve the txt file chosen in the dropdown to a path for ``gr.File``.

    Args:
        selected_name: File name selected in the UI (e.g. ``Text_12.txt``).
            May be ``None`` or empty when nothing has been selected.

    Returns:
        The path of the selected file inside ``REWRITE_DIR``.

    Raises:
        ValueError: If no file name was supplied.
        FileNotFoundError: If the name does not refer to an existing regular
            file inside ``REWRITE_DIR``.
    """
    if not selected_name:
        # os.path.join(REWRITE_DIR, None) would otherwise fail with an
        # opaque TypeError when the dropdown is empty.
        raise ValueError("No txt file selected.")

    # The value arrives from the UI; keep only the final path component so a
    # value such as "../app.py" cannot point outside REWRITE_DIR.
    safe_name = os.path.basename(str(selected_name))
    path = os.path.join(REWRITE_DIR, safe_name)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Not found: {path}")
    return path
335
+
336
def build_zip_of_txts(mode: str, n_last: int) -> str:
    """Bundle generated ``Text_*.txt`` files into ``rewrite_passages.zip``.

    Args:
        mode: ``'all'`` to include every txt file, ``'last_n'`` to include
            only the ``n_last`` most recently modified ones. Any other value
            behaves like ``'all'``.
        n_last: Number of newest files to include when ``mode == 'last_n'``.
            Values below 1, and values that cannot be converted to an int
            (for example ``None`` from a cleared number field), are treated
            as 1.

    Returns:
        Path of the zip file written inside ``OUT_DIR``.

    Raises:
        FileNotFoundError: If ``REWRITE_DIR`` contains no ``Text_*.txt`` file.
    """
    # Local imports keep this block self-contained; the file's import header
    # does not provide these modules.
    import tempfile
    import zipfile

    files = sorted(glob.glob(os.path.join(REWRITE_DIR, "Text_*.txt")), key=os.path.getmtime)
    if not files:
        raise FileNotFoundError("No generated txt files yet.")

    if mode == "last_n":
        try:
            count = int(n_last)
        except (TypeError, ValueError, OverflowError):
            count = 1
        files = files[-max(1, count):]

    zip_path = os.path.join(OUT_DIR, "rewrite_passages.zip")

    # Write the archive to a uniquely named temporary file and move it into
    # place afterwards. This avoids the shared "_zip_tmp" staging directory
    # (which concurrent requests would delete from under each other) and
    # never exposes a partially written zip at zip_path.
    fd, tmp_path = tempfile.mkstemp(suffix=".zip", dir=OUT_DIR)
    os.close(fd)
    try:
        with zipfile.ZipFile(tmp_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for fp in files:
                # Store files flat, under their base names only.
                zf.write(fp, arcname=os.path.basename(fp))
        os.replace(tmp_path, zip_path)
    finally:
        # After a successful os.replace the temp name no longer exists; this
        # only cleans up when archiving failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return zip_path
369
+
370
def download_csv() -> str:
    """Return ``CSV_PATH`` for download.

    Raises:
        FileNotFoundError: If the CSV file does not exist yet.
    """
    if os.path.exists(CSV_PATH):
        return CSV_PATH
    raise FileNotFoundError("rewrite_scores.csv is not created yet.")
374
+
375
def refresh_txt_dropdown() -> gr.Dropdown:
    """Build a dropdown update listing the currently generated txt files.

    The last name in the list is preselected, or nothing when the list is empty.
    """
    available = list_generated_txt_files()
    selected = available[-1] if available else None
    return gr.Dropdown(choices=available, value=selected)
378
+
379
  # =========================
380
  # Gradio UI(Spaces向け)
381
  # =========================
 
400
  reset_btn = gr.Button("CSVリセット(削除)")
401
  reset_txt_btn = gr.Button("rewrite_passagesリセット(Text_*.txt削除)")
402
 
403
+ gr.Markdown("## 📥 Download")
404
+ gr.Markdown("Free Space では Files 生成物が見えないことがあるので、ここからダウンロードしてください。")
405
 
406
+ with gr.Row():
407
+ refresh_btn = gr.Button("txt一覧を更新")
408
+ txt_dropdown = gr.Dropdown(choices=list_generated_txt_files(), label="生成済み txt を選択(1個ダウンロード)")
409
+ download_one_btn = gr.Button("選択したtxtをダウンロード")
410
+ download_one_file = gr.File(label="Download (single txt)")
411
+
412
+ with gr.Row():
413
+ zip_mode = gr.Radio(
414
+ choices=[("全txtをまとめてZIP", "all"), ("最新N個をZIP", "last_n")],
415
+ value="all",
416
+ label="ZIPの作り方"
417
+ )
418
+ last_n = gr.Number(value=5, precision=0, label="N(最新N個の場合)")
419
+ download_zip_btn = gr.Button("ZIPを作ってダウンロード")
420
+ download_zip_file = gr.File(label="Download (zip)")
421
+
422
+ with gr.Row():
423
+ download_csv_btn = gr.Button("CSV(rewrite_scores.csv)をダウンロード")
424
+ download_csv_file = gr.File(label="Download (csv)")
425
+
426
+ # 既存の表示(参考)
427
+ gr.Markdown(f"📄 CSVパス: `{CSV_PATH}`")
428
+ gr.Markdown(f"📝 TXTフォルダ: `{REWRITE_DIR}`")
429
+
430
+ # ---- 既存ボタン ----
431
+ start_btn.click(fn=start, inputs=[level], outputs=[state, status, header, progress, output_text, download_one_file, download_zip_file])
432
+ one_btn.click(fn=run_one, inputs=[level, state], outputs=[state, status, header, progress, output_text, download_one_file, download_zip_file])
433
+ all_btn.click(fn=run_all, inputs=[level, state], outputs=[state, status, header, progress, output_text, download_one_file, download_zip_file])
434
  stop_btn.click(fn=stop, inputs=[], outputs=[status])
435
  reset_btn.click(fn=reset_csv, inputs=[], outputs=[status])
436
  reset_txt_btn.click(fn=reset_rewrite_folder, inputs=[], outputs=[status])
437
 
438
+ # ---- Download UI ----
439
+ refresh_btn.click(fn=refresh_txt_dropdown, inputs=[], outputs=[txt_dropdown])
440
+ download_one_btn.click(fn=build_single_txt_path, inputs=[txt_dropdown], outputs=[download_one_file])
441
+ download_zip_btn.click(fn=build_zip_of_txts, inputs=[zip_mode, last_n], outputs=[download_zip_file])
442
+ download_csv_btn.click(fn=download_csv, inputs=[], outputs=[download_csv_file])
443
+
444
  demo.queue(max_size=64)
445
  demo.launch()