manishw7 committed on
Commit
9ebb598
·
1 Parent(s): ecce7a8

Design: Final Premium Suite with Pro Mode Toggle

Browse files
Files changed (1) hide show
  1. app.py +46 -45
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import io
 
3
  import gradio as gr
4
  import torch
5
  import numpy as np
@@ -17,8 +18,8 @@ CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # --- MODEL LOADING ---
21
- print("System: Loading Engine with Visual Debug...")
22
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
23
  base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
24
 
@@ -32,16 +33,14 @@ base_model.config.vocab_size = base_model.config.decoder.vocab_size
32
  peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
33
  try:
34
  model = peft_model.merge_and_unload()
35
- print("System: LoRA weights merged.")
36
  except Exception:
37
  model = peft_model
38
  model.to(device)
39
  model.eval()
40
 
41
- # Load CNN
42
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
43
 
44
- # --- ORIGINAL ROUTING LOGIC ---
45
  def _flood_fill(binary, visited, start_y, start_x, h, w):
46
  stack = [(start_y, start_x)]
47
  size = 0
@@ -75,31 +74,32 @@ def original_classify_input(image):
75
  cmin, cmax = np.where(cols)[0][[0, -1]]
76
  w, h = cmax - cmin + 1, rmax - rmin + 1
77
  ar, bc = w/h, count_blobs(binary, min_size=max(binary.size * 0.001, 10))
 
78
  is_char = True
79
  if ar > 2.5: is_char = False
80
  elif ar > 1.8 and bc >= 3: is_char = False
81
  elif bc >= 4: is_char = False
82
  elif ar < 1.3 and bc <= 2: is_char = True
83
  elif bc == 1 and ar < 1.5: is_char = True
84
- elif ar < 1.75 and bc <= 2: is_char = True # <--- RESTORED THIS LINE
85
  elif ar > 1.6: is_char = False
 
86
  return ("character" if is_char else "word"), ar, bc
87
 
88
- # --- PREDICT ---
89
- def predict(image):
90
  if image is None: return None, None, "Upload image.", ""
91
 
92
- # 1. PREPROCESS (Critical!)
93
  buf = io.BytesIO()
94
  image.save(buf, format="PNG")
95
- image_bytes = buf.getvalue()
96
-
97
- preprocessed_pil = preprocess_for_ocr(image_bytes)
98
- if preprocessed_pil is None: return None, None, "Preprocessing Failed", ""
99
 
100
- # 2. ROUTE
101
- mode, ar, bc = original_classify_input(preprocessed_pil)
102
- status = f"Mode: {mode.upper()} | AR: {ar:.2f} | Blobs: {bc}"
 
 
 
103
 
104
  try:
105
  if mode == "character" and cnn_engine.available:
@@ -108,43 +108,44 @@ def predict(image):
108
  else:
109
  pixel_values = processor(preprocessed_pil, return_tensors="pt").pixel_values.to(device)
110
  with torch.no_grad():
111
- outputs = model.generate(
112
- pixel_values,
113
- num_beams=4,
114
- max_length=128,
115
- early_stopping=True,
116
- decoder_start_token_id=model.config.decoder_start_token_id
117
- )
118
- text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
119
  return preprocessed_pil, text, status, "TrOCR + LoRA"
120
  except Exception as e:
121
- return preprocessed_pil, f"Error: {str(e)}", "Inference Failed", "None"
122
 
123
- # --- UI ---
124
  CSS = """
125
- .gradio-container { background: #0f172a; color: white; font-family: 'Inter', sans-serif; }
126
- .panel { background: rgba(30, 41, 59, 0.8); border-radius: 20px; padding: 20px; border: 1px solid #334155; }
127
- .result-text { font-size: 2.2rem !important; font-weight: bold; color: #818cf8; text-align: center; background: rgba(0,0,0,0.3) !important; border-radius: 12px; }
 
 
 
128
  """
129
 
130
- with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
131
- gr.Markdown("# 🕉️ DevGen OCR — Diagnostic Suite")
132
- gr.Markdown("See exactly what the engine sees to debug quality issues.")
133
-
134
- with gr.Row(elem_classes="panel"):
135
- with gr.Column():
136
- input_img = gr.Image(type="pil", label="1. Original Upload")
137
- run_btn = gr.Button("🔍 Run Diagnostic Recognition", variant="primary")
138
-
139
- with gr.Column():
140
- processed_img = gr.Image(type="pil", label="2. What the Model Sees", interactive=False)
141
- output_text = gr.Textbox(label="3. Recognition Result", elem_classes="result-text")
142
 
143
- with gr.Row(elem_classes="panel"):
144
- status_lbl = gr.Markdown("## System Insights\nReady to analyze.")
145
- engine_lbl = gr.Textbox(label="Model Used", interactive=False)
 
 
 
 
 
 
146
 
147
- run_btn.click(predict, [input_img], [processed_img, output_text, status_lbl, engine_lbl])
148
 
149
  if __name__ == "__main__":
150
  demo.launch()
 
1
  import os
2
  import io
3
+ import time
4
  import gradio as gr
5
  import torch
6
  import numpy as np
 
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # --- ENGINE CORE ---
22
+ print("System: Initializing DevGen Premium Engine...")
23
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
24
  base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
25
 
 
33
  peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
34
  try:
35
  model = peft_model.merge_and_unload()
 
36
  except Exception:
37
  model = peft_model
38
  model.to(device)
39
  model.eval()
40
 
 
41
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
42
 
43
+ # --- ORIGINAL ROUTING ---
44
  def _flood_fill(binary, visited, start_y, start_x, h, w):
45
  stack = [(start_y, start_x)]
46
  size = 0
 
74
  cmin, cmax = np.where(cols)[0][[0, -1]]
75
  w, h = cmax - cmin + 1, rmax - rmin + 1
76
  ar, bc = w/h, count_blobs(binary, min_size=max(binary.size * 0.001, 10))
77
+
78
  is_char = True
79
  if ar > 2.5: is_char = False
80
  elif ar > 1.8 and bc >= 3: is_char = False
81
  elif bc >= 4: is_char = False
82
  elif ar < 1.3 and bc <= 2: is_char = True
83
  elif bc == 1 and ar < 1.5: is_char = True
84
+ elif ar < 1.75 and bc <= 2: is_char = True
85
  elif ar > 1.6: is_char = False
86
+
87
  return ("character" if is_char else "word"), ar, bc
88
 
89
+ # --- PIPELINE ---
90
+ def predict(image, manual_mode):
91
  if image is None: return None, None, "Upload image.", ""
92
 
 
93
  buf = io.BytesIO()
94
  image.save(buf, format="PNG")
95
+ preprocessed_pil = preprocess_for_ocr(buf.getvalue())
 
 
 
96
 
97
+ if manual_mode == "Automatic":
98
+ mode, ar, bc = original_classify_input(preprocessed_pil)
99
+ status = f"**System Insight**: Auto-detected **{mode.upper()}** (AR: {ar:.2f}, Blobs: {bc})"
100
+ else:
101
+ mode = manual_mode.lower()
102
+ status = f"**System Insight**: Manual Override set to **{mode.upper()}**"
103
 
104
  try:
105
  if mode == "character" and cnn_engine.available:
 
108
  else:
109
  pixel_values = processor(preprocessed_pil, return_tensors="pt").pixel_values.to(device)
110
  with torch.no_grad():
111
+ gen = model.generate(pixel_values, num_beams=4, max_length=128, early_stopping=True, decoder_start_token_id=model.config.decoder_start_token_id)
112
+ text = processor.batch_decode(gen, skip_special_tokens=True)[0]
 
 
 
 
 
 
113
  return preprocessed_pil, text, status, "TrOCR + LoRA"
114
  except Exception as e:
115
+ return preprocessed_pil, f"Inference Error: {str(e)}", "Process Failed", "None"
116
 
117
+ # --- PREMIUM CSS ---
118
  CSS = """
119
+ @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@400;600&family=Inter:wght@400;500&display=swap');
120
+ .gradio-container { background: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%) !important; color: white !important; font-family: 'Inter', sans-serif !important; }
121
+ .premium-card { background: rgba(30, 41, 59, 0.7) !important; backdrop-filter: blur(12px); border: 1px solid rgba(255,255,255,0.1); border-radius: 24px; padding: 2rem; box-shadow: 0 25px 50px -12px rgba(0,0,0,0.5); margin-bottom: 20px; }
122
+ h1 { font-family: 'Outfit', sans-serif; font-size: 3rem !important; font-weight: 600; background: linear-gradient(90deg, #818cf8, #c084fc); -webkit-background-clip: text; -webkit-fill-color: transparent; margin-bottom: 1rem; }
123
+ .result-box { font-size: 2.5rem !important; font-weight: 600; text-align: center; color: #818cf8; background: rgba(0,0,0,0.2) !important; border: 1px solid rgba(129, 140, 248, 0.3) !important; border-radius: 16px !important; }
124
+ .btn-primary { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%) !important; border: none !important; border-radius: 12px !important; font-family: 'Outfit', sans-serif !important; font-weight: 600 !important; }
125
  """
126
 
127
+ with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
128
+ with gr.Column(elem_classes="premium-card"):
129
+ gr.Markdown("# 🕉️ DevGen OCR")
130
+ gr.Markdown("A high-fidelity neuro-generative OCR suite for Devanagari.")
131
+
132
+ with gr.Row():
133
+ with gr.Column(scale=1):
134
+ img_in = gr.Image(type="pil", label="Input Document", mirror_webcam=False)
135
+ mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Recognition Logic")
136
+ sub_btn = gr.Button("Recognize Handwriting", variant="primary", elem_classes="btn-primary")
 
 
137
 
138
+ with gr.Column(scale=1):
139
+ text_out = gr.Textbox(label="Recognition Result", elem_classes="result-box", interactive=False)
140
+ status_md = gr.Markdown("Engine is ready.")
141
+ engine_txt = gr.Textbox(label="Active Model", interactive=False)
142
+
143
+ with gr.Accordion("🛠️ Technical Diagnostics", open=False):
144
+ with gr.Row(elem_classes="premium-card"):
145
+ img_proc = gr.Image(type="pil", label="Preprocessed Input (What the model sees)", interactive=False)
146
+ gr.Markdown("### Processing Notes\nThis view shows the image after binarization and aspect-ratio normalization. If the image here is blurry or cut off, it may affect accuracy.")
147
 
148
+ sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt])
149
 
150
  if __name__ == "__main__":
151
  demo.launch()