ZennyKenny committed on
Commit 91450bb · verified · 1 Parent(s): 5001bc4

Update app.py

Files changed (1)
  1. app.py +129 -76
app.py CHANGED
app.py BEFORE (old version; removed lines marked -)

@@ -1,20 +1,37 @@
import os
from pathlib import Path

import gradio as gr
import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import spaces  # ZeroGPU


# ========= Config =========
-# Base model + your LoRA adapter (override via Space Secrets if needed)
MODEL_ID_BASE = os.getenv("BASE_MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_REPO = os.getenv("ADAPTER_REPO", "ZennyKenny/oss-20b-prereform-to-modern-ru-merged")
-ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoint-60")  # change if your adapter folder differs

-# ========= Load external system prompt =========
def _load_system_prompt():
    path = Path(__file__).with_name("text-prompt.py")
    default = (
@@ -26,31 +43,85 @@ def _load_system_prompt():
    try:
        ns = {}
        if path.exists():
-            exec(path.read_text(encoding='utf-8'), ns)
        return ns.get("SYSTEM_PROMPT", default)
    except Exception:
        return default

SYSTEM_PROMPT = _load_system_prompt()

-def build_prompt(text: str) -> str:
    return (
        f"{SYSTEM_PROMPT}\n\n"
-        f"Текст (дореформ.):\n{text.strip()}\n\n"
        f"Текст (современная орфография):"
    )

-# ========= ZeroGPU inference =========
-@spaces.GPU(duration=180)  # GPU is leased only while this function runs
-def _infer_zerogpu(prompt: str, gen_kwargs: dict) -> str:
-    # Tokenizer from adapter repo (it contains tokenizer files)
-    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO, use_fast=True, trust_remote_code=True)

-    # Ensure pad token exists; if not, align it with EOS (common for GPT-like)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

-    # Load base model on GPU with appropriate dtype
    torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_ID_BASE,
@@ -59,43 +130,42 @@ def _infer_zerogpu(prompt: str, gen_kwargs: dict) -> str:
        device_map="auto",
    )

-    # Apply LoRA adapter from your repo/subfolder
    model = PeftModel.from_pretrained(base, ADAPTER_REPO, subfolder=ADAPTER_SUBFOLDER)
-
-    # Optional: merge LoRA for faster generation
    try:
        model = model.merge_and_unload()
    except Exception:
        pass

-    # Sync pad_token_id to model config to avoid warnings
    try:
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception:
        pass

-    # ----- Tokenize & always pass attention_mask -----
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc.get("attention_mask", torch.ones_like(input_ids)).to(model.device)

-    # Reasonable defaults
-    gen_kwargs = dict(gen_kwargs or {})
-    gen_kwargs.setdefault("use_cache", True)

-    # ----- Generate -----
    with torch.no_grad():
        out_ids = model.generate(
            input_ids=input_ids,
-            attention_mask=attention_mask,  # Key fix for pad==eos
            **gen_kwargs,
        )

-    # Decode ONLY the continuation (exclude prompt tokens)
    continuation = out_ids[0, input_ids.shape[1]:]
    out = tokenizer.decode(continuation, skip_special_tokens=True).strip()

-    # Fallback to full decode if continuation is empty (still no letter-replacement fallback)
    if not out:
        full = tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()
        marker = "Текст (современная орфография):"
@@ -103,73 +173,56 @@

    return out

# ========= Orchestrator =========
-def convert(text, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-    if not text or not text.strip():
-        return ""

-    prompt = build_prompt(text)
-    gen_kwargs = dict(
-        max_new_tokens=int(max_new_tokens),
-        temperature=float(temperature),
-        top_p=float(top_p),
-        top_k=int(top_k),
-        repetition_penalty=float(repetition_penalty),
-        do_sample=True,
-    )

-    # ZeroGPU-only path; if it fails, show an informative message (no rule-based output)
-    try:
-        return _infer_zerogpu(prompt, gen_kwargs)
-    except Exception as e:
-        return f"[Ошибка ZeroGPU: {type(e).__name__}: {e}]"

# ========= UI =========
-with gr.Blocks(title="Pre-reform → Modern Russian (ZeroGPU)") as demo:
    gr.Markdown(
        """
-        # Преобразование дореформенной → современной орфографии
-        Запросы выполняются на **ZeroGPU** (GPU выделяется только на время генерации).
        """
    )

    with gr.Row():
        with gr.Column():
-            inp = gr.Textbox(
-                label="Ввод: дореформенный текст",
                placeholder="Например: \"въ мирѣ сёмъ многа есть...\"",
-                lines=10
            )
-            with gr.Accordion("Параметры генерации", open=False):
-                max_new_tokens = gr.Slider(16, 512, value=192, step=8, label="max_new_tokens")
-                temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="temperature")
-                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
-                top_k = gr.Slider(0, 100, value=40, step=1, label="top_k")
-                repetition_penalty = gr.Slider(1.0, 2.0, value=1.05, step=0.01, label="repetition_penalty")
-            btn = gr.Button("Преобразовать", variant="primary")
        with gr.Column():
-            out = gr.Textbox(label="Вывод: современная орфография", lines=14)
-
-    gr.Examples(
-        examples=[
-            # Classic prose examples
-            ["въ семъ домѣ обитало три семейства, и каждое имѣло свои обыкновенія."],
-            ["Онъ шёлъ по узкой улѣцѣ, разсматривая вывѣски лавокъ и фонари."],
-            ["въ мирѣ сёмъ многа есть, чего мудрецу и не снилось."],
-            # Orthography stress tests
-            ["Сей образъ мыслей былъ въ обычаѣ: въслѣдствіе того, что ѣще не наступило прояснѣніе."],
-            ["Именіе его находилось на уѣздной окраинѣ; крестьяне имѣли обыкновеніе собираться къ вечеру."],
-            ["Лѣтописи глаголютъ, яко многа бывало чудесъ на рѣкѣ сей."],
-            ["Оный человѣкъ писалъ послѣднія строки при свѣтѣ фонаря, на улицѣ безлюдной."],
-            ["Въ семъ письмѣ обрѣтёте вы извѣстія, коихъ до нынѣ не имѣли."],
-        ],
-        inputs=[inp],
-    )

    btn.click(
-        lambda t,a,b,c,d,e: convert(t, a, b, c, d, e),
-        inputs=[inp, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[out],
    )

if __name__ == "__main__":
 
app.py AFTER (full new file; added lines marked +, unchanged spans elided with …)

import os
from pathlib import Path
+from typing import Optional, Tuple

import gradio as gr
import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+    pipeline,
+)
from peft import PeftModel
import spaces  # ZeroGPU


# ========= Config =========
MODEL_ID_BASE = os.getenv("BASE_MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_REPO = os.getenv("ADAPTER_REPO", "ZennyKenny/oss-20b-prereform-to-modern-ru-merged")
+ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoint-60")

+OCR_MODEL_ID = os.getenv("OCR_MODEL_ID", "ChatDOC/OCRFlux-3B")
+
+OCR_MAX_NEW_TOKENS = int(os.getenv("OCR_MAX_NEW_TOKENS", "6000"))
+CONVERT_MAX_NEW_TOKENS = int(os.getenv("CONVERT_MAX_NEW_TOKENS", "6000"))
+
+TEMPERATURE = float(os.getenv("CONVERT_TEMPERATURE", "0.2"))
+TOP_P = float(os.getenv("CONVERT_TOP_P", "0.9"))
+TOP_K = int(os.getenv("CONVERT_TOP_K", "40"))
+REPETITION_PENALTY = float(os.getenv("CONVERT_REP_PENALTY", "1.05"))
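All of these settings read from environment variables, so they can be overridden via Space secrets/variables rather than code edits. A minimal local-run sketch (hypothetical values; the variables must be set before app.py is imported, since the os.getenv() calls above run at import time):

import os

os.environ["CONVERT_MAX_NEW_TOKENS"] = "2048"      # smaller budget for short snippets
os.environ["CONVERT_TEMPERATURE"] = "0.1"          # tighter sampling, more literal output
os.environ["ADAPTER_SUBFOLDER"] = "checkpoint-60"  # default shown; point at another checkpoint if needed

import app  # the config above is read here, at import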
+
+
+# ========= Load prompts =========
def _load_system_prompt():
    path = Path(__file__).with_name("text-prompt.py")
    default = (
…
    try:
        ns = {}
        if path.exists():
+            exec(path.read_text(encoding="utf-8"), ns)
        return ns.get("SYSTEM_PROMPT", default)
    except Exception:
        return default

SYSTEM_PROMPT = _load_system_prompt()

+# OCR prompt in its own file
+def _load_ocr_prompt():
+    path = Path(__file__).with_name("ocr-prompt.py")
+    default = (
+        "Извлеки из изображения весь текст БУКВАЛЬНО и на русском языке. "
+        "Ничего не переводить и не исправлять. "
+        "Сохраняй дореформенную орфографию и специальные символы. "
+        "Верни только чистый текст (plain text)."
+    )
+    try:
+        ns = {}
+        if path.exists():
+            exec(path.read_text(encoding="utf-8"), ns)
+        return ns.get("OCR_PROMPT", default)
+    except Exception:
+        return default
+
+OCR_PROMPT = _load_ocr_prompt()
+
+
+def build_conversion_prompt(pre_reform_text: str) -> str:
    return (
        f"{SYSTEM_PROMPT}\n\n"
+        f"Текст (дореформ.):\n{pre_reform_text.strip()}\n\n"
        f"Текст (современная орфография):"
    )
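For reference, the assembled conversion prompt has this shape (illustrative; SYSTEM_PROMPT is whatever text-prompt.py provided, or the built-in default):

print(build_conversion_prompt("Онъ шёлъ по улицѣ."))
# <SYSTEM_PROMPT>
#
# Текст (дореформ.):
# Онъ шёлъ по улицѣ.
#
# Текст (современная орфография):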

+# ========= ZeroGPU: OCR step =========
+@spaces.GPU(duration=300)  # 5 minutes
+def _ocr_image_to_text(image) -> str:
+    processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
+
+    torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+    ocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        OCR_MODEL_ID,
+        trust_remote_code=True,
+        torch_dtype=torch_dtype,
+        device_map="auto",
+    )
+
+    ocr_pipe = pipeline(
+        task="image-text-to-text",
+        model=ocr_model,
+        processor=processor,
+    )
+
+    out = ocr_pipe(
+        image,
+        prompt=OCR_PROMPT,
+        max_new_tokens=OCR_MAX_NEW_TOKENS,
+        temperature=0.0,
+        do_sample=False,
+    )
+
+    if isinstance(out, list) and len(out) > 0:
+        text = out[0].get("generated_text", "") or out[0].get("text", "")
+    elif isinstance(out, str):
+        text = out
+    else:
+        text = ""
+
+    return (text or "").strip()
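The OCR step can be smoke-tested outside the Space. A minimal sketch, assuming scan.png is a placeholder filename, the spaces decorator is inert when no ZeroGPU is attached, and the local machine has a CUDA GPU with enough memory for OCRFlux-3B:

from PIL import Image

page = Image.open("scan.png")   # hypothetical scan of pre-reform text
raw = _ocr_image_to_text(page)  # assumption: @spaces.GPU is a no-op outside a Space
print(raw[:300])                # eyeball the extracted pre-reform text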
+
+
+# ========= ZeroGPU: Conversion step =========
+@spaces.GPU(duration=300)  # 5 minutes
+def _convert_text_zerogpu(pre_reform_text: str) -> str:
+    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_ID_BASE,
…
        device_map="auto",
    )

    model = PeftModel.from_pretrained(base, ADAPTER_REPO, subfolder=ADAPTER_SUBFOLDER)
    try:
        model = model.merge_and_unload()
    except Exception:
        pass

    try:
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception:
        pass

+    prompt = build_conversion_prompt(pre_reform_text)
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc.get("attention_mask", torch.ones_like(input_ids)).to(model.device)

+    gen_kwargs = dict(
+        max_new_tokens=CONVERT_MAX_NEW_TOKENS,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+        top_k=TOP_K,
+        repetition_penalty=REPETITION_PENALTY,
+        do_sample=True,
+        use_cache=True,
+    )

    with torch.no_grad():
        out_ids = model.generate(
            input_ids=input_ids,
+            attention_mask=attention_mask,
            **gen_kwargs,
        )

    continuation = out_ids[0, input_ids.shape[1]:]
    out = tokenizer.decode(continuation, skip_special_tokens=True).strip()

    if not out:
        full = tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()
        marker = "Текст (современная орфография):"
…

    return out

+
# ========= Orchestrator =========
+def process(image, manual_text):
+    pre_reform_from_ocr = ""
+    if image is not None:
+        pre_reform_from_ocr = _ocr_image_to_text(image)

+    combined = ""
+    if manual_text and manual_text.strip():
+        combined = manual_text.strip()
+    if pre_reform_from_ocr:
+        combined = (combined + "\n\n" + pre_reform_from_ocr).strip() if combined else pre_reform_from_ocr
+
+    if not combined:
+        return "", ""
+
+    modern_text = _convert_text_zerogpu(combined)
+    return modern_text, pre_reform_from_ocr


# ========= UI =========
+with gr.Blocks(title="Pre-reform → Modern Russian (OCR + ZeroGPU)") as demo:
    gr.Markdown(
        """
+        # Преобразование дореформенной → современной орфографии (с OCR)
+        1) Загрузите изображение с дореформенным текстом (фотография/скан), **или** вставьте текст вручную.
+        2) Модель **OCRFlux-3B** извлечёт текст, затем **OSS-20B + LoRA** преобразует его в современную орфографию.
+        **Параметры генерации скрыты и настроены для длинных документов (≈ 6 000 токенов).**
        """
    )

    with gr.Row():
        with gr.Column():
+            img = gr.Image(label="Изображение с дореформенным текстом", type="pil")
+            manual = gr.Textbox(
+                label="(Необязательно) Вставьте дореформенный текст вручную",
+                lines=10,
                placeholder="Например: \"въ мирѣ сёмъ многа есть...\"",
            )
+            btn = gr.Button("Распознать и преобразовать", variant="primary")
        with gr.Column():
+            out_modern = gr.Textbox(label="Современная орфография (результат)", lines=18)
+            with gr.Accordion("Промежуточный текст из OCR (для проверки)", open=False):
+                out_ocr = gr.Textbox(label="Текст из OCRFlux-3B", lines=12)

    btn.click(
+        fn=process,
+        inputs=[img, manual],
+        outputs=[out_modern, out_ocr],
+        api_name="process",
    )

if __name__ == "__main__":
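
Because btn.click registers api_name="process", the Space also exposes this flow as a programmatic endpoint. A sketch of calling it with gradio_client (the Space id is a placeholder; substitute the real one):

from gradio_client import Client, handle_file

client = Client("ZennyKenny/your-space-name")  # placeholder Space id
modern, ocr_text = client.predict(
    handle_file("scan.png"),  # image input (or None to skip OCR)
    "",                       # optional manual pre-reform text
    api_name="/process",
)
print(modern)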