Cursor Agent AHAD commited on
Commit
2400aac
·
unverified ·
1 Parent(s): 53a9529

Use text branch for text-only inference on text+ASR models

Browse files

Pass inputs_asr=None instead of mirroring input as pseudo-ASR, which produced
undiacritized output. Implement direct encode/forward/decode to avoid Diac
predict_text empty-list .to() bug.

Co-authored-by: AHAD <ahad-m@users.noreply.github.com>

Files changed (2) hide show
  1. README.md +1 -1
  2. backend/app/services/inference.py +133 -21
README.md CHANGED
@@ -10,7 +10,7 @@
10
 
11
  **النموذج الافتراضي:** `rufaelfekadu/diac-transformer-text-asr-tashkeela-clartts`
12
 
13
- > **نموذج text+ASR:** عند إرسال نص فقط (بدون مخرجات Whisper)، يستخدم النظام النص نفسه كمدخل ASR بديلًا حتى يعمل الاستنتاج.
14
 
15
  ## المتطلبات
16
 
 
10
 
11
  **النموذج الافتراضي:** `rufaelfekadu/diac-transformer-text-asr-tashkeela-clartts`
12
 
13
+ > **نموذج text+ASR:** عند إرسال نص فقط (بدون مخرجات Whisper)، يُستخدم فرع النص في النموذج (`inputs_asr=None`). لنتائج أفضل على الكلام، مرّر نص ASR من Whisper في حقل ASR (قريبًا) أو استخدم نموذج `text-only`.
14
 
15
  ## المتطلبات
16
 
backend/app/services/inference.py CHANGED
@@ -4,6 +4,8 @@ from __future__ import annotations
4
 
5
  from typing import TYPE_CHECKING, Any
6
 
 
 
7
  if TYPE_CHECKING:
8
  from diac.models import DiacritizationModule
9
 
@@ -16,42 +18,152 @@ def _max_length(model: DiacritizationModule) -> int:
16
  return int(getattr(model.config.INFERENCE, "MAX_LENGTH", 500))
17
 
18
 
19
- def _asr_inputs_for_texts(model: DiacritizationModule, texts: list[str]) -> list[str]:
20
- """
21
- text+ASR checkpoints require ASR tensors during predict_text().
 
 
 
22
 
23
- When callers provide text only, mirror the undiacritized input as pseudo-ASR
24
- so encode_batch returns a tensor instead of an empty list (which breaks .to()).
25
- """
26
- if not _uses_asr(model):
27
- return []
28
- return list(texts)
29
 
 
 
 
 
 
 
 
30
 
31
- def predict_diacritized(model: DiacritizationModule, texts: list[str]) -> list[str]:
32
- """Run diacritization for one or more strings."""
 
 
33
  cleaned = [t.strip() for t in texts]
34
  if not cleaned:
35
  return []
36
 
 
37
  max_len = _max_length(model)
38
- asr_texts = _asr_inputs_for_texts(model, cleaned)
39
 
40
  if all(len(text) <= max_len for text in cleaned):
41
- outputs = model.predict_text(cleaned, asr_text=asr_texts)
42
- return _normalize_outputs(outputs, len(cleaned))
43
 
44
  results: list[str] = []
45
- for text, asr in zip(cleaned, asr_texts or [""] * len(cleaned), strict=True):
46
- asr_arg = asr if _uses_asr(model) else []
47
- window_out = model.predict_sliding_window(text, asr_text=asr_arg)
48
- if isinstance(window_out, list):
49
- results.append(window_out[0] if window_out else "")
50
- else:
51
- results.append(str(window_out))
52
  return results
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def _normalize_outputs(outputs: Any, expected: int) -> list[str]:
56
  if isinstance(outputs, list):
57
  if len(outputs) != expected:
 
4
 
5
  from typing import TYPE_CHECKING, Any
6
 
7
+ import torch
8
+
9
  if TYPE_CHECKING:
10
  from diac.models import DiacritizationModule
11
 
 
18
  return int(getattr(model.config.INFERENCE, "MAX_LENGTH", 500))
19
 
20
 
21
+ def _window_size(model: DiacritizationModule) -> int:
22
+ return int(getattr(model.config.INFERENCE, "WINDOW_SIZE", 50))
23
+
24
+
25
+ def _buffer_size(model: DiacritizationModule) -> int:
26
+ return int(getattr(model.config.INFERENCE, "BUFFER_SIZE", 25))
27
 
 
 
 
 
 
 
28
 
29
+ def predict_diacritized(
30
+ model: DiacritizationModule,
31
+ texts: list[str],
32
+ asr_texts: list[str] | None = None,
33
+ ) -> list[str]:
34
+ """
35
+ Run diacritization for one or more strings.
36
 
37
+ When no ASR text is supplied, the model's text branch is used (inputs_asr=None)
38
+ even for text+ASR checkpoints. This matches TransformerModel.forward and avoids
39
+ Diac predict_text bugs with empty ASR tensors.
40
+ """
41
  cleaned = [t.strip() for t in texts]
42
  if not cleaned:
43
  return []
44
 
45
+ asr = _normalize_asr_inputs(cleaned, asr_texts)
46
  max_len = _max_length(model)
 
47
 
48
  if all(len(text) <= max_len for text in cleaned):
49
+ return _predict_batch(model, cleaned, asr)
 
50
 
51
  results: list[str] = []
52
+ for text, asr_line in zip(cleaned, asr or [None] * len(cleaned), strict=True):
53
+ results.append(_predict_sliding_window(model, text, asr_line))
 
 
 
 
 
54
  return results
55
 
56
 
57
+ def _normalize_asr_inputs(
58
+ texts: list[str], asr_texts: list[str] | None
59
+ ) -> list[str] | None:
60
+ if not asr_texts:
61
+ return None
62
+ if len(asr_texts) != len(texts):
63
+ raise ValueError("asr_texts length must match texts length")
64
+ cleaned_asr = [t.strip() for t in asr_texts]
65
+ if not all(cleaned_asr):
66
+ return None
67
+ return cleaned_asr
68
+
69
+
70
+ def _predict_batch(
71
+ model: DiacritizationModule,
72
+ texts: list[str],
73
+ asr_texts: list[str] | None,
74
+ ) -> list[str]:
75
+ model.model.eval()
76
+ use_asr = _uses_asr(model) and asr_texts is not None
77
+
78
+ encoded_text, encoded_asr, _ = model.tokenizer.encode_batch(
79
+ texts,
80
+ asr_texts if use_asr else [],
81
+ padding=True,
82
+ )
83
+ encoded_text = encoded_text.to(model.device)
84
+ encoded_asr = _prepare_asr_tensor(model, encoded_asr, use_asr)
85
+
86
+ with torch.no_grad():
87
+ outputs = model.model(encoded_text, inputs_asr=encoded_asr)
88
+ predictions = outputs.argmax(dim=-1).cpu().tolist()
89
+
90
+ decoded = model.tokenizer.decode_batch(predictions, texts)
91
+ return _normalize_outputs(decoded, len(texts))
92
+
93
+
94
+ def _predict_sliding_window(
95
+ model: DiacritizationModule,
96
+ text: str,
97
+ asr_text: str | None,
98
+ ) -> str:
99
+ from diac.utils.text import remove_diacritics
100
+
101
+ model.model.eval()
102
+ text = remove_diacritics(text).strip()
103
+ if not text:
104
+ return ""
105
+
106
+ asr_text = asr_text or ""
107
+ max_len = _max_length(model)
108
+
109
+ if len(text) <= max_len:
110
+ batch = _predict_batch(model, [text], [asr_text] if asr_text else None)
111
+ return batch[0]
112
+
113
+ window_size = _window_size(model)
114
+ buffer_size = _buffer_size(model)
115
+ ratio = len(asr_text) / len(text) if asr_text else 1.0
116
+
117
+ start_idx = 0
118
+ end_idx = window_size
119
+ output = ""
120
+
121
+ while end_idx <= len(text):
122
+ start = max(0, start_idx - buffer_size)
123
+ end = min(len(text), end_idx + buffer_size)
124
+ end_idx = min(len(text), start_idx + window_size)
125
+
126
+ chunk = text[start:end]
127
+ chunk_asr = asr_text[int(start * ratio) : int(end * ratio)] if asr_text else ""
128
+
129
+ encoded_chunk, encoded_asr_chunk, _ = model.tokenizer.encode(
130
+ chunk,
131
+ chunk_asr or None,
132
+ return_tensor=True,
133
+ )
134
+ encoded_chunk = encoded_chunk.to(model.device)
135
+ encoded_asr_chunk = _prepare_asr_tensor(
136
+ model,
137
+ encoded_asr_chunk,
138
+ _uses_asr(model) and bool(chunk_asr),
139
+ )
140
+
141
+ with torch.no_grad():
142
+ outputs = model.model(encoded_chunk, inputs_asr=encoded_asr_chunk).squeeze(0)
143
+ predictions = outputs.argmax(dim=-1).cpu().tolist()
144
+
145
+ decoded_chunk = model.tokenizer.decode(
146
+ predictions[start_idx - start : end_idx - start],
147
+ chunk[start_idx - start : end_idx - start],
148
+ )
149
+ output += decoded_chunk
150
+ start_idx = end_idx
151
+
152
+ return output
153
+
154
+
155
+ def _prepare_asr_tensor(
156
+ model: DiacritizationModule,
157
+ encoded_asr: Any,
158
+ use_asr: bool,
159
+ ) -> torch.Tensor | None:
160
+ if not use_asr:
161
+ return None
162
+ if isinstance(encoded_asr, torch.Tensor):
163
+ return encoded_asr.to(model.device)
164
+ return None
165
+
166
+
167
  def _normalize_outputs(outputs: Any, expected: int) -> list[str]:
168
  if isinstance(outputs, list):
169
  if len(outputs) != expected: