manishw7 commited on
Commit
6cd700e
·
1 Parent(s): ced8950

Final: Stable All-In-One Suite (UI + Debug + Scoring + Fixes)

Browse files
Files changed (1) hide show
  1. app.py +48 -70
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import io
3
- import time
4
  import gradio as gr
5
  import torch
6
  import numpy as np
@@ -11,80 +10,72 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
11
  from cnn_model import CharacterClassifier
12
  from preprocessing import preprocess_for_ocr
13
 
14
- # --- SURGICAL MONKEY-PATCH FOR GRADIO 4.x ---
15
- try:
16
- import gradio_client.utils
17
- original_fn = gradio_client.utils.json_schema_to_python_type
18
- def patched_fn(schema, *args, **kwargs):
19
- if isinstance(schema, bool): return "Any"
20
- return original_fn(schema, *args, **kwargs)
21
- gradio_client.utils.json_schema_to_python_type = patched_fn
22
- except Exception: pass
23
- # --------------------------------------------
 
 
24
 
25
  # --- CONFIGURATION ---
26
  BASE_MODEL_ID = "paudelanil/trocr-devanagari-2"
27
  ADAPTER_ID = "manishw10/devgen-trocr-devanagari-lora"
28
  CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
29
-
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
- # --- ENGINE CORE ---
33
- print("System: Initializing Full Suite (Gradio 4.x Patched)...")
34
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
35
  base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
36
-
37
- # Sync Token Configs
38
  base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
39
  base_model.config.pad_token_id = processor.tokenizer.pad_token_id
40
  base_model.config.eos_token_id = processor.tokenizer.sep_token_id
41
  base_model.config.vocab_size = base_model.config.decoder.vocab_size
42
 
43
- # Apply and Merge PEFT
44
  peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
45
  try:
46
  model = peft_model.merge_and_unload()
47
  except Exception:
48
  model = peft_model
49
- model.to(device)
50
- model.eval()
51
-
52
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
53
 
54
- # --- ORIGINAL ROUTING ---
55
  def _flood_fill(binary, visited, start_y, start_x, h, w):
56
  stack = [(start_y, start_x)]
57
  size = 0
58
  while stack:
59
  y, x = stack.pop()
60
  if y<0 or y>=h or x<0 or x>=w or visited[y,x] or not binary[y,x]: continue
61
- visited[y,x] = True
62
- size += 1
63
  stack.extend([(y+1,x),(y-1,x),(y,x+1),(y,x-1)])
64
  return size
65
 
66
- def count_blobs(binary, min_size=10):
67
- h, w = binary.shape
68
- visited = np.zeros_like(binary, dtype=bool)
69
- count = 0
70
  for y in range(h):
71
  for x in range(w):
72
  if binary[y,x] and not visited[y,x]:
73
  size = _flood_fill(binary, visited, y, x, h, w)
74
- if size >= min_size: count += 1
75
  return count
76
 
77
  def original_classify_input(image):
78
- gray = image.convert("L")
79
- arr = np.array(gray)
80
  threshold = min(arr.mean() * 0.75, 200)
81
  binary = (arr < threshold).astype(np.uint8)
82
  rows, cols = np.any(binary, axis=1), np.any(binary, axis=0)
83
  if not rows.any() or not cols.any(): return "character", 1.0, 1
84
- coords = np.column_stack(np.where(binary > 0))
85
- y0, x0 = coords.min(axis=0); y1, x1 = coords.max(axis=0)
86
  w, h = x1-x0+1, y1-y0+1
87
- ar, bc = w/h, count_blobs(binary, min_size=max(binary.size * 0.001, 10))
88
  is_char = True
89
  if ar > 2.5: is_char = False
90
  elif ar > 1.8 and bc >= 3: is_char = False
@@ -97,55 +88,44 @@ def original_classify_input(image):
97
 
98
  def get_confidence_html(confidence):
99
  color = "#10b981" if confidence > 0.9 else "#f59e0b" if confidence > 0.7 else "#ef4444"
100
- return f"""
101
- <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 15px; background: rgba(0,0,0,0.2); border-radius: 20px;">
102
- <div style="position: relative; width: 100px; height: 100px;">
103
- <svg width="100" height="100" viewBox="0 0 100 100">
104
- <circle cx="50" cy="50" r="45" fill="none" stroke="rgba(255,255,255,0.1)" stroke-width="8" />
105
- <circle cx="50" cy="50" r="45" fill="none" stroke="{color}" stroke-width="8"
106
- stroke-dasharray="282.7" stroke-dashoffset="{282.7 * (1 - confidence)}"
107
- stroke-linecap="round" style="transition: stroke-dashoffset 1s ease-out;" />
108
- </svg>
109
- <div style="position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); font-size: 1.2rem; font-weight: bold; font-family: 'Outfit'; color: {color};">
110
- {int(confidence * 100)}%
111
- </div>
112
- </div>
113
- </div>
114
- """
115
 
116
  # --- PREDICT ---
117
  def predict(image, manual_mode):
118
  if image is None: return None, None, "Upload image.", "", ""
119
  buf = io.BytesIO(); image.save(buf, format="PNG")
120
- preprocessed_pil = preprocess_for_ocr(buf.getvalue())
121
  if manual_mode == "Automatic":
122
- mode, ar, bc = original_classify_input(preprocessed_pil)
123
- status = f"**System Insight**: {mode.upper()} detected (AR: {ar:.2f}, Blobs: {bc})"
124
  else:
125
  mode = manual_mode.lower(); status = f"**Manual Mode**: {mode.upper()}"
126
  try:
127
  if mode == "character" and cnn_engine.available:
128
- result = cnn_engine.predict(preprocessed_pil)
129
- return preprocessed_pil, result["text"], status, "CNN Classifier", get_confidence_html(result["confidence"])
130
  else:
131
- pixel_values = processor(preprocessed_pil, return_tensors="pt").pixel_values.to(device)
132
  with torch.no_grad():
133
- outputs = model.generate(pixel_values, num_beams=4, max_length=128, early_stopping=True, return_dict_in_generate=True, output_scores=True, decoder_start_token_id=model.config.decoder_start_token_id)
134
- transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
135
- avg_conf = float(torch.exp(transition_scores[0]).mean().item())
136
- text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
137
- return preprocessed_pil, text, status, "TrOCR + LoRA", get_confidence_html(avg_conf)
138
  except Exception as e:
139
- return preprocessed_pil, f"Error: {str(e)}", "Failed", "None", ""
140
 
141
- # --- PREMIUM CSS ---
142
  CSS = """
143
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;600&family=Inter:wght@400;500&display=swap');
144
- .gradio-container { background: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%) !important; color: white !important; font-family: 'Inter', sans-serif !important; }
145
- .premium-card { background: rgba(30, 41, 59, 0.7) !important; backdrop-filter: blur(12px); border: 1px solid rgba(255,255,255,0.1); border-radius: 24px; padding: 2rem; box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5); }
146
  .result-box { font-size: 3rem !important; font-weight: 600; text-align: center; color: #818cf8; background: transparent !important; border: none !important; }
147
- .btn-primary { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important; border: none !important; border-radius: 12px !important; font-family: 'Outfit', sans-serif !important; font-weight: 600 !important; }
148
- .diagnostic-panel { margin-top: 30px; border-top: 1px solid rgba(255,255,255,0.1); padding-top: 20px; }
149
  """
150
 
151
  with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
@@ -155,19 +135,17 @@ with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
155
  with gr.Column(scale=1):
156
  img_in = gr.Image(type="pil", label="Input Handwriting")
157
  mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Logic Mode")
158
- sub_btn = gr.Button("Recognize", variant="primary", elem_classes="btn-primary")
159
  with gr.Column(scale=1):
160
  conf_html = gr.HTML()
161
  text_out = gr.Textbox(label="Result", elem_classes="result-box", interactive=False, show_label=False)
162
  status_md = gr.Markdown("Engine ready.")
163
  engine_txt = gr.Textbox(label="Active Model", interactive=False)
164
-
165
- with gr.Column(elem_classes="diagnostic-panel"):
166
  gr.Markdown("### 🛠️ Visual Debug: What the Model Sees")
167
  img_proc = gr.Image(type="pil", label="Preprocessed Input", interactive=False, show_label=False)
168
 
169
- # EVENT HANDLER (Now correctly inside the Blocks context)
170
  sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt, conf_html])
171
 
172
  if __name__ == "__main__":
173
- demo.launch()
 
1
  import os
2
  import io
 
3
  import gradio as gr
4
  import torch
5
  import numpy as np
 
10
  from cnn_model import CharacterClassifier
11
  from preprocessing import preprocess_for_ocr
12
 
13
+ # --- ROBUST GLOBAL PATCH FOR GRADIO 4.x ---
14
+ import gradio_client.utils
15
+ def robust_get_type(schema):
16
+ if isinstance(schema, bool): return "Any"
17
+ if not isinstance(schema, dict): return "Any"
18
+ if "const" in schema: return "Any"
19
+ return original_get_type(schema)
20
+
21
+ if hasattr(gradio_client.utils, "get_type"):
22
+ original_get_type = gradio_client.utils.get_type
23
+ gradio_client.utils.get_type = robust_get_type
24
+ # ------------------------------------------
25
 
26
  # --- CONFIGURATION ---
27
  BASE_MODEL_ID = "paudelanil/trocr-devanagari-2"
28
  ADAPTER_ID = "manishw10/devgen-trocr-devanagari-lora"
29
  CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+ # --- ENGINE INITIALIZATION ---
33
+ print("System: Initializing Full Combined Suite...")
34
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
35
  base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
 
 
36
  base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
37
  base_model.config.pad_token_id = processor.tokenizer.pad_token_id
38
  base_model.config.eos_token_id = processor.tokenizer.sep_token_id
39
  base_model.config.vocab_size = base_model.config.decoder.vocab_size
40
 
 
41
  peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
42
  try:
43
  model = peft_model.merge_and_unload()
44
  except Exception:
45
  model = peft_model
46
+ model.to(device); model.eval()
 
 
47
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
48
 
49
+ # --- ORIGINAL ROUTING LOGIC ---
50
  def _flood_fill(binary, visited, start_y, start_x, h, w):
51
  stack = [(start_y, start_x)]
52
  size = 0
53
  while stack:
54
  y, x = stack.pop()
55
  if y<0 or y>=h or x<0 or x>=w or visited[y,x] or not binary[y,x]: continue
56
+ visited[y,x] = True; size += 1
 
57
  stack.extend([(y+1,x),(y-1,x),(y,x+1),(y,x-1)])
58
  return size
59
 
60
+ def count_blobs(binary):
61
+ h, w = binary.shape; visited = np.zeros_like(binary, dtype=bool); count = 0
 
 
62
  for y in range(h):
63
  for x in range(w):
64
  if binary[y,x] and not visited[y,x]:
65
  size = _flood_fill(binary, visited, y, x, h, w)
66
+ if size >= max(binary.size * 0.001, 10): count += 1
67
  return count
68
 
69
  def original_classify_input(image):
70
+ gray = image.convert("L"); arr = np.array(gray)
 
71
  threshold = min(arr.mean() * 0.75, 200)
72
  binary = (arr < threshold).astype(np.uint8)
73
  rows, cols = np.any(binary, axis=1), np.any(binary, axis=0)
74
  if not rows.any() or not cols.any(): return "character", 1.0, 1
75
+ y0, x0 = np.where(rows)[0][0], np.where(cols)[0][0]
76
+ y1, x1 = np.where(rows)[0][-1], np.where(cols)[0][-1]
77
  w, h = x1-x0+1, y1-y0+1
78
+ ar, bc = w/h, count_blobs(binary)
79
  is_char = True
80
  if ar > 2.5: is_char = False
81
  elif ar > 1.8 and bc >= 3: is_char = False
 
88
 
89
  def get_confidence_html(confidence):
90
  color = "#10b981" if confidence > 0.9 else "#f59e0b" if confidence > 0.7 else "#ef4444"
91
+ return f"""<div style="display: flex; flex-direction: column; align-items: center; background: rgba(0,0,0,0.2); border-radius: 20px; padding: 15px;">
92
+ <svg width="100" height="100" viewBox="0 0 100 100">
93
+ <circle cx="50" cy="50" r="45" fill="none" stroke="rgba(255,255,255,0.1)" stroke-width="8" />
94
+ <circle cx="50" cy="50" r="45" fill="none" stroke="{color}" stroke-width="8" stroke-dasharray="282.7" stroke-dashoffset="{282.7 * (1 - confidence)}" stroke-linecap="round" style="transition: stroke-dashoffset 1s;" />
95
+ <text x="50" y="55" font-family="Outfit" font-size="20" font-weight="bold" fill="{color}" text-anchor="middle">{int(confidence * 100)}%</text>
96
+ </svg>
97
+ </div>"""
 
 
 
 
 
 
 
 
98
 
99
  # --- PREDICT ---
100
  def predict(image, manual_mode):
101
  if image is None: return None, None, "Upload image.", "", ""
102
  buf = io.BytesIO(); image.save(buf, format="PNG")
103
+ pre_pil = preprocess_for_ocr(buf.getvalue())
104
  if manual_mode == "Automatic":
105
+ mode, ar, bc = original_classify_input(pre_pil)
106
+ status = f"**System**: {mode.upper()} detected (AR: {ar:.2f}, Blobs: {bc})"
107
  else:
108
  mode = manual_mode.lower(); status = f"**Manual Mode**: {mode.upper()}"
109
  try:
110
  if mode == "character" and cnn_engine.available:
111
+ res = cnn_engine.predict(pre_pil)
112
+ return pre_pil, res["text"], status, "CNN Classifier", get_confidence_html(res["confidence"])
113
  else:
114
+ pixel_values = processor(pre_pil, return_tensors="pt").pixel_values.to(device)
115
  with torch.no_grad():
116
+ out = model.generate(pixel_values, num_beams=4, max_length=128, early_stopping=True, return_dict_in_generate=True, output_scores=True, decoder_start_token_id=model.config.decoder_start_token_id)
117
+ scores = torch.exp(model.compute_transition_scores(out.sequences, out.scores, normalize_logits=True)[0])
118
+ txt = processor.batch_decode(out.sequences, skip_special_tokens=True)[0]
119
+ return pre_pil, txt, status, "TrOCR + LoRA", get_confidence_html(float(scores.mean().item()))
 
120
  except Exception as e:
121
+ return pre_pil, f"Error: {str(e)}", "Failed", "None", ""
122
 
123
+ # --- PREMIUM UI ---
124
  CSS = """
125
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;600&family=Inter:wght@400;500&display=swap');
126
+ .gradio-container { background: #0f172a !important; color: white !important; font-family: 'Inter', sans-serif !important; }
127
+ .premium-card { background: rgba(30, 41, 59, 0.7) !important; backdrop-filter: blur(12px); border: 1px solid rgba(255,255,255,0.1); border-radius: 24px; padding: 2rem; box-shadow: 0 25px 50px -12px rgba(0,0,0,0.5); }
128
  .result-box { font-size: 3rem !important; font-weight: 600; text-align: center; color: #818cf8; background: transparent !important; border: none !important; }
 
 
129
  """
130
 
131
  with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
 
135
  with gr.Column(scale=1):
136
  img_in = gr.Image(type="pil", label="Input Handwriting")
137
  mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Logic Mode")
138
+ sub_btn = gr.Button("Recognize", variant="primary")
139
  with gr.Column(scale=1):
140
  conf_html = gr.HTML()
141
  text_out = gr.Textbox(label="Result", elem_classes="result-box", interactive=False, show_label=False)
142
  status_md = gr.Markdown("Engine ready.")
143
  engine_txt = gr.Textbox(label="Active Model", interactive=False)
144
+ with gr.Column():
 
145
  gr.Markdown("### 🛠️ Visual Debug: What the Model Sees")
146
  img_proc = gr.Image(type="pil", label="Preprocessed Input", interactive=False, show_label=False)
147
 
 
148
  sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt, conf_html])
149
 
150
  if __name__ == "__main__":
151
+ demo.launch(server_name="0.0.0.0", server_port=7860)