manishw7 commited on
Commit
e8bc8af
·
1 Parent(s): 42d5462

Layout: Move visual diagnostic to main screen and preserve all logic

Browse files
Files changed (1) hide show
  1. app.py +53 -59
app.py CHANGED
@@ -19,7 +19,7 @@ CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # --- ENGINE CORE ---
22
- print("System: Initializing Premium Engine with Confidence Scoring...")
23
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
24
  base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
25
 
@@ -40,7 +40,29 @@ model.eval()
40
 
41
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
42
 
43
- # --- ROUTING & ANALYTICS ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def original_classify_input(image):
45
  gray = image.convert("L")
46
  arr = np.array(gray)
@@ -48,27 +70,10 @@ def original_classify_input(image):
48
  binary = (arr < threshold).astype(np.uint8)
49
  rows, cols = np.any(binary, axis=1), np.any(binary, axis=0)
50
  if not rows.any() or not cols.any(): return "character", 1.0, 1
51
- rmin, rmax = np.where(rows)[0][[0, -1]], np.where(cols)[0][[0, -1]]
52
- w, h = rmax[1]-rmax[0]+1 if len(rmax)>1 else 32, rmin[1]-rmin[0]+1 if len(rmin)>1 else 32 # fallback
53
  coords = np.column_stack(np.where(binary > 0))
54
  y0, x0 = coords.min(axis=0); y1, x1 = coords.max(axis=0)
55
  w, h = x1-x0+1, y1-y0+1
56
- ar = w/max(h,1)
57
-
58
- # Re-implementing iterative flood fill for blob count
59
- visited = np.zeros_like(binary, dtype=bool)
60
- bc = 0
61
- for y in range(binary.shape[0]):
62
- for x in range(binary.shape[1]):
63
- if binary[y,x] and not visited[y,x]:
64
- stack = [(y,x)]
65
- size = 0
66
- while stack:
67
- py, px = stack.pop()
68
- if py<0 or py>=binary.shape[0] or px<0 or px>=binary.shape[1] or visited[py,px] or not binary[py,px]: continue
69
- visited[py,px] = True; size += 1
70
- stack.extend([(py+1,px),(py-1,px),(py,px+1),(py,px-1)])
71
- if size >= max(binary.size * 0.001, 10): bc += 1
72
 
73
  is_char = True
74
  if ar > 2.5: is_char = False
@@ -82,37 +87,33 @@ def original_classify_input(image):
82
 
83
  def get_confidence_html(confidence):
84
  color = "#10b981" if confidence > 0.9 else "#f59e0b" if confidence > 0.7 else "#ef4444"
85
- label = "High Certainty" if confidence > 0.9 else "Likely Correct" if confidence > 0.7 else "Review Recommended"
86
  return f"""
87
- <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 20px; background: rgba(0,0,0,0.2); border-radius: 20px; border: 1px solid rgba(255,255,255,0.1);">
88
- <div style="position: relative; width: 120px; height: 120px;">
89
- <svg width="120" height="120" viewBox="0 0 120 120">
90
- <circle cx="60" cy="60" r="54" fill="none" stroke="rgba(255,255,255,0.1)" stroke-width="8" />
91
- <circle cx="60" cy="60" r="54" fill="none" stroke="{color}" stroke-width="8"
92
- stroke-dasharray="339.29" stroke-dashoffset="{339.29 * (1 - confidence)}"
93
  stroke-linecap="round" style="transition: stroke-dashoffset 1s ease-out;" />
94
  </svg>
95
- <div style="position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); font-size: 1.5rem; font-weight: bold; font-family: 'Outfit'; color: {color};">
96
  {int(confidence * 100)}%
97
  </div>
98
  </div>
99
- <div style="margin-top: 10px; font-family: 'Inter'; font-size: 0.9rem; color: #94a3b8;">{label}</div>
100
  </div>
101
  """
102
 
103
  # --- PREDICT ---
104
  def predict(image, manual_mode):
105
  if image is None: return None, None, "Upload image.", "", ""
106
-
107
  buf = io.BytesIO()
108
  image.save(buf, format="PNG")
109
  preprocessed_pil = preprocess_for_ocr(buf.getvalue())
110
-
111
  if manual_mode == "Automatic":
112
  mode, ar, bc = original_classify_input(preprocessed_pil)
113
- status = f"**System Insight**: Auto-detected **{mode.upper()}** (AR: {ar:.2f}, Blobs: {bc})"
114
  else:
115
- mode = manual_mode.lower(); status = f"**System Insight**: Manual Mode **{mode.upper()}**"
116
 
117
  try:
118
  if mode == "character" and cnn_engine.available:
@@ -121,52 +122,45 @@ def predict(image, manual_mode):
121
  else:
122
  pixel_values = processor(preprocessed_pil, return_tensors="pt").pixel_values.to(device)
123
  with torch.no_grad():
124
- outputs = model.generate(
125
- pixel_values, num_beams=4, max_length=128, early_stopping=True,
126
- return_dict_in_generate=True, output_scores=True,
127
- decoder_start_token_id=model.config.decoder_start_token_id
128
- )
129
-
130
- # Calculate word-level confidence (Mirrored from local engine)
131
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
132
- confidences = torch.exp(transition_scores[0])
133
- avg_conf = float(confidences.mean().item())
134
-
135
  text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
136
  return preprocessed_pil, text, status, "TrOCR + LoRA", get_confidence_html(avg_conf)
137
  except Exception as e:
138
- return preprocessed_pil, f"Error: {str(e)}", "Process Failed", "None", ""
139
 
140
- # --- PREMIUM UI ---
141
  CSS = """
142
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;600&family=Inter:wght@400;500&display=swap');
143
- .gradio-container { background: #0f172a !important; color: white !important; font-family: 'Inter', sans-serif !important; }
144
  .premium-card { background: rgba(30, 41, 59, 0.7) !important; backdrop-filter: blur(12px); border: 1px solid rgba(255,255,255,0.1); border-radius: 24px; padding: 2rem; box-shadow: 0 25px 50px -12px rgba(0,0,0,0.5); }
145
- .result-box { font-size: 3rem !important; font-weight: 600; text-align: center; color: #818cf8; background: transparent !important; border: none !important; margin-top: 10px; }
146
- .btn-primary { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important; border: none !important; border-radius: 12px !important; font-family: 'Outfit', sans-serif !important; font-weight: 600 !important; padding: 12px !important; }
 
147
  """
148
 
149
  with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
150
  with gr.Column(elem_classes="premium-card"):
151
  gr.Markdown("# 🕉️ DevGen OCR")
152
- gr.Markdown("High-fidelity Devanagari recognition with real-time confidence metrics.")
153
-
154
  with gr.Row():
155
  with gr.Column(scale=1):
156
- img_in = gr.Image(type="pil", label="Input Document")
157
  mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Logic Mode")
158
- sub_btn = gr.Button("Recognize Handwriting", variant="primary", elem_classes="btn-primary")
159
-
160
  with gr.Column(scale=1):
161
- conf_html = gr.HTML(label="Confidence Gauge")
162
- text_out = gr.Textbox(label="Result", elem_classes="result-box", interactive=False, show_label=False)
163
- status_md = gr.Markdown("Engine is ready.")
164
  engine_txt = gr.Textbox(label="Active Model", interactive=False)
165
 
166
- with gr.Accordion("🛠️ Technical Diagnostics", open=False):
167
- img_proc = gr.Image(type="pil", label="Processed Input", interactive=False)
 
 
 
168
 
169
- sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt, conf_html])
170
 
171
  if __name__ == "__main__":
172
  demo.launch()
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # --- ENGINE CORE ---
22
+ print("System: Initializing Full Suite with Confidence and Visual Debug...")
23
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
24
  base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
25
 
 
40
 
41
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
42
 
43
+ # --- ORIGINAL ROUTING ---
44
+ def _flood_fill(binary, visited, start_y, start_x, h, w):
45
+ stack = [(start_y, start_x)]
46
+ size = 0
47
+ while stack:
48
+ y, x = stack.pop()
49
+ if y<0 or y>=h or x<0 or x>=w or visited[y,x] or not binary[y,x]: continue
50
+ visited[y,x] = True
51
+ size += 1
52
+ stack.extend([(y+1,x),(y-1,x),(y,x+1),(y,x-1)])
53
+ return size
54
+
55
+ def count_blobs(binary, min_size=10):
56
+ h, w = binary.shape
57
+ visited = np.zeros_like(binary, dtype=bool)
58
+ count = 0
59
+ for y in range(h):
60
+ for x in range(w):
61
+ if binary[y,x] and not visited[y,x]:
62
+ size = _flood_fill(binary, visited, y, x, h, w)
63
+ if size >= min_size: count += 1
64
+ return count
65
+
66
  def original_classify_input(image):
67
  gray = image.convert("L")
68
  arr = np.array(gray)
 
70
  binary = (arr < threshold).astype(np.uint8)
71
  rows, cols = np.any(binary, axis=1), np.any(binary, axis=0)
72
  if not rows.any() or not cols.any(): return "character", 1.0, 1
 
 
73
  coords = np.column_stack(np.where(binary > 0))
74
  y0, x0 = coords.min(axis=0); y1, x1 = coords.max(axis=0)
75
  w, h = x1-x0+1, y1-y0+1
76
+ ar, bc = w/h, count_blobs(binary, min_size=max(binary.size * 0.001, 10))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  is_char = True
79
  if ar > 2.5: is_char = False
 
87
 
88
  def get_confidence_html(confidence):
89
  color = "#10b981" if confidence > 0.9 else "#f59e0b" if confidence > 0.7 else "#ef4444"
 
90
  return f"""
91
+ <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 15px; background: rgba(0,0,0,0.2); border-radius: 20px;">
92
+ <div style="position: relative; width: 100px; height: 100px;">
93
+ <svg width="100" height="100" viewBox="0 0 100 100">
94
+ <circle cx="50" cy="50" r="45" fill="none" stroke="rgba(255,255,255,0.1)" stroke-width="8" />
95
+ <circle cx="50" cy="50" r="45" fill="none" stroke="{color}" stroke-width="8"
96
+ stroke-dasharray="282.7" stroke-dashoffset="{282.7 * (1 - confidence)}"
97
  stroke-linecap="round" style="transition: stroke-dashoffset 1s ease-out;" />
98
  </svg>
99
+ <div style="position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); font-size: 1.2rem; font-weight: bold; font-family: 'Outfit'; color: {color};">
100
  {int(confidence * 100)}%
101
  </div>
102
  </div>
 
103
  </div>
104
  """
105
 
106
  # --- PREDICT ---
107
  def predict(image, manual_mode):
108
  if image is None: return None, None, "Upload image.", "", ""
 
109
  buf = io.BytesIO()
110
  image.save(buf, format="PNG")
111
  preprocessed_pil = preprocess_for_ocr(buf.getvalue())
 
112
  if manual_mode == "Automatic":
113
  mode, ar, bc = original_classify_input(preprocessed_pil)
114
+ status = f"**System Insight**: {mode.upper()} detected (AR: {ar:.2f}, Blobs: {bc})"
115
  else:
116
+ mode = manual_mode.lower(); status = f"**Manual Mode**: {mode.upper()}"
117
 
118
  try:
119
  if mode == "character" and cnn_engine.available:
 
122
  else:
123
  pixel_values = processor(preprocessed_pil, return_tensors="pt").pixel_values.to(device)
124
  with torch.no_grad():
125
+ outputs = model.generate(pixel_values, num_beams=4, max_length=128, early_stopping=True, return_dict_in_generate=True, output_scores=True, decoder_start_token_id=model.config.decoder_start_token_id)
 
 
 
 
 
 
126
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
127
+ avg_conf = float(torch.exp(transition_scores[0]).mean().item())
 
 
128
  text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
129
  return preprocessed_pil, text, status, "TrOCR + LoRA", get_confidence_html(avg_conf)
130
  except Exception as e:
131
+ return preprocessed_pil, f"Error: {str(e)}", "Failed", "None", ""
132
 
133
+ # --- PREMIUM CSS ---
134
  CSS = """
135
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;600&family=Inter:wght@400;500&display=swap');
136
+ .gradio-container { background: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%) !important; color: white !important; font-family: 'Inter', sans-serif !important; }
137
  .premium-card { background: rgba(30, 41, 59, 0.7) !important; backdrop-filter: blur(12px); border: 1px solid rgba(255,255,255,0.1); border-radius: 24px; padding: 2rem; box-shadow: 0 25px 50px -12px rgba(0,0,0,0.5); }
138
+ .result-box { font-size: 3rem !important; font-weight: 600; text-align: center; color: #818cf8; background: transparent !important; border: none !important; }
139
+ .btn-primary { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important; border: none !important; border-radius: 12px !important; font-family: 'Outfit', sans-serif !important; font-weight: 600 !important; }
140
+ .diagnostic-panel { margin-top: 30px; border-top: 1px solid rgba(255,255,255,0.1); padding-top: 20px; }
141
  """
142
 
143
  with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
144
  with gr.Column(elem_classes="premium-card"):
145
  gr.Markdown("# 🕉️ DevGen OCR")
 
 
146
  with gr.Row():
147
  with gr.Column(scale=1):
148
+ img_in = gr.Image(type="pil", label="Input Handwriting")
149
  mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Logic Mode")
150
+ sub_btn = gr.Button("Recognize", variant="primary", elem_classes="btn-primary")
 
151
  with gr.Column(scale=1):
152
+ conf_html = gr.HTML()
153
+ text_out = gr.Textbox(label="Recognition Result", elem_classes="result-box", interactive=False, show_label=False)
154
+ status_md = gr.Markdown("Engine ready.")
155
  engine_txt = gr.Textbox(label="Active Model", interactive=False)
156
 
157
+ with gr.Column(elem_classes="diagnostic-panel"):
158
+ gr.Markdown("### 🛠️ Visual Debug: What the Model Sees")
159
+ img_proc = gr.Image(type="pil", label="Preprocessed Input", interactive=False, show_label=False)
160
+
161
+ gr.Markdown("Built by DevGen Team.")
162
 
163
+ sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt, conf_html])
164
 
165
  if __name__ == "__main__":
166
  demo.launch()