manishw7 commited on
Commit
ae3fa31
·
1 Parent(s): adb49d0

Feature: Full Smart Pipeline with Auto-Routing and Preprocessing

Browse files
Files changed (1) hide show
  1. app.py +112 -63
app.py CHANGED
@@ -2,23 +2,23 @@ import os
2
  import gradio as gr
3
  import torch
4
  import numpy as np
 
5
  from PIL import Image
6
  from peft import PeftModel
7
  from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor, VisionEncoderDecoderModel
8
- from cnn_model import CharacterClassifier # Importing your CNN logic
9
 
10
  # --- CONFIGURATION ---
11
  BASE_MODEL_ID = "paudelanil/trocr-devanagari-2"
12
  ADAPTER_ID = "manishw10/devgen-trocr-devanagari-lora"
13
  CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
14
 
15
- # Detect environment
16
  IS_SPACE = "SPACE_ID" in os.environ
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- print(f"System: Initializing Models (Env: {'Hugging Face Space' if IS_SPACE else 'Local'})")
 
20
 
21
- # 1. Load TrOCR Model & Processor
22
  try:
23
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
24
  except Exception:
@@ -31,77 +31,126 @@ model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
31
  model.to(device)
32
  model.eval()
33
 
34
- # 2. Load CNN Classifier
35
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
36
-
37
  print(f"System: Models loaded successfully on {device}")
38
 
39
- def predict_trocr(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if image is None:
41
- return "Error: No image uploaded"
 
42
  try:
43
- image = image.convert("RGB")
44
- pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
 
45
 
46
- # --- HIGH-QUALITY GENERATION ---
47
- # Added num_beams and length_penalty to fix the "rubbish" output.
48
- # This makes TrOCR use Beam Search instead of Greedy Search.
49
- with torch.no_grad():
50
- generated_ids = model.base_model.generate(
51
- pixel_values=pixel_values,
52
- num_beams=4,
53
- length_penalty=1.0,
54
- max_new_tokens=64,
55
- early_stopping=True,
56
- decoder_start_token_id=model.config.decoder_start_token_id
57
- )
58
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
59
- return generated_text
 
 
 
 
 
 
60
  except Exception as e:
61
  import traceback
62
  print(traceback.format_exc())
63
- return f"TrOCR Error: {str(e)}"
64
 
65
- def predict_cnn(image):
66
- if image is None:
67
- return "Error: No image uploaded"
68
- try:
69
- image = image.convert("RGB")
70
- result = cnn_engine.predict(image)
71
- if "error" in result:
72
- return result["error"]
73
- return f"Character: {result['text']} (Confidence: {result['confidence']:.2%})"
74
- except Exception as e:
75
- return f"CNN Error: {str(e)}"
76
-
77
- # --- CUSTOM GRADIO INTERFACE ---
78
- with gr.Blocks(title="DevGen OCR Suite") as demo:
79
- gr.Markdown("# 🕉️ DevGen Devanagari OCR Suite")
80
- gr.Markdown("Switch between TrOCR (for words/sentences) and CNN (for single characters).")
81
 
82
- with gr.Tabs():
83
- with gr.TabItem("TrOCR (Word/Sentence Recognition)"):
84
- with gr.Row():
85
- with gr.Column():
86
- img_input = gr.Image(type="pil", label="Upload Handwritten Word")
87
- btn_trocr = gr.Button("Recognize Word", variant="primary")
88
- with gr.Column():
89
- text_output = gr.Textbox(label="Recognized Text")
90
- btn_trocr.click(fn=predict_trocr, inputs=img_input, outputs=text_output)
91
 
92
- with gr.TabItem("CNN (Single Character Recognition)"):
93
- with gr.Row():
94
- with gr.Column():
95
- char_input = gr.Image(type="pil", label="Upload Single Character")
96
- btn_cnn = gr.Button("Classify Character", variant="primary")
97
- with gr.Column():
98
- char_output = gr.Textbox(label="Classification Result")
99
- btn_cnn.click(fn=predict_cnn, inputs=char_input, outputs=char_output)
100
 
101
- gr.Markdown("---")
102
- gr.Markdown("Built with ❤️ by DevGen Team. Using TrOCR + LoRA and custom 3-layer CNN.")
 
 
 
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
- server_name = "0.0.0.0" if IS_SPACE else "127.0.0.1"
106
- # Note: We don't use monkey-patching here, the base_model.generate handles it.
107
- demo.launch(server_name=server_name)
 
2
  import gradio as gr
3
  import torch
4
  import numpy as np
5
+ import cv2
6
  from PIL import Image
7
  from peft import PeftModel
8
  from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor, VisionEncoderDecoderModel
9
+ from cnn_model import CharacterClassifier
10
 
11
  # --- CONFIGURATION ---
12
  BASE_MODEL_ID = "paudelanil/trocr-devanagari-2"
13
  ADAPTER_ID = "manishw10/devgen-trocr-devanagari-lora"
14
  CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
15
 
 
16
  IS_SPACE = "SPACE_ID" in os.environ
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
+ # --- MODEL INITIALIZATION ---
20
+ print(f"System: Initializing Smart Engine (Env: {'Hugging Face Space' if IS_SPACE else 'Local'})")
21
 
 
22
  try:
23
  processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
24
  except Exception:
 
31
  model.to(device)
32
  model.eval()
33
 
 
34
  cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
 
35
  print(f"System: Models loaded successfully on {device}")
36
 
37
+ # --- SMART ROUTING LOGIC ---
38
+ def _count_blobs(binary, min_size=10):
39
+ h, w = binary.shape
40
+ visited = np.zeros_like(binary, dtype=bool)
41
+ count = 0
42
+ for y in range(h):
43
+ for x in range(w):
44
+ if binary[y, x] and not visited[y, x]:
45
+ # Simple iterative flood fill
46
+ stack = [(y, x)]
47
+ size = 0
48
+ while stack:
49
+ py, px = stack.pop()
50
+ if py<0 or py>=h or px<0 or px>=w or visited[py, px] or not binary[py, px]:
51
+ continue
52
+ visited[py, px] = True
53
+ size += 1
54
+ stack.extend([(py+1, px), (py-1, px), (py, px+1), (py, px-1)])
55
+ if size >= min_size:
56
+ count += 1
57
+ return count
58
+
59
+ def classify_input(image):
60
+ gray = image.convert("L")
61
+ arr = np.array(gray)
62
+ threshold = min(arr.mean() * 0.75, 200)
63
+ binary = (arr < threshold).astype(np.uint8)
64
+
65
+ rows = np.any(binary, axis=1)
66
+ cols = np.any(binary, axis=0)
67
+ if not rows.any() or not cols.any():
68
+ return "character", 0.5, "no_ink"
69
+
70
+ rmin, rmax = np.where(rows)[0][[0, -1]]
71
+ cmin, cmax = np.where(cols)[0][[0, -1]]
72
+ aspect_ratio = (cmax - cmin + 1) / max(rmax - rmin + 1, 1)
73
+ blob_count = _count_blobs(binary, min_size=max(binary.size * 0.001, 10))
74
+
75
+ if aspect_ratio > 1.8 or blob_count >= 3:
76
+ return "word", 0.9, "wide_or_multiple_blobs"
77
+ return "character", 0.8, "square_compact"
78
+
79
+ # --- PREPROCESSING ---
80
+ def preprocess_for_trocr(image):
81
+ # Standard cleanup for word recognition
82
+ image = image.convert("RGB")
83
+ # Tightly crop to ink
84
+ gray = np.array(image.convert("L"))
85
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
86
+ coords = np.column_stack(np.where(binary > 0))
87
+ if len(coords) > 0:
88
+ y0, x0 = coords.min(axis=0)
89
+ y1, x1 = coords.max(axis=0)
90
+ # Pad slightly
91
+ image = image.crop((max(0, x0-10), max(0, y0-10), min(image.width, x1+10), min(image.height, y1+10)))
92
+ return image
93
+
94
+ # --- MAIN INFERENCE PIPELINE ---
95
+ def smart_predict(image):
96
  if image is None:
97
+ return "Please upload an image.", "Waiting...", "None"
98
+
99
  try:
100
+ # 1. Smart Routing
101
+ input_type, confidence, reason = classify_input(image)
102
+ system_status = f"Mode: {input_type.upper()} | Reason: {reason} (Conf: {confidence:.0%})"
103
 
104
+ if input_type == "character" and cnn_engine.available:
105
+ # 2. CNN Pipeline
106
+ result = cnn_engine.predict(image)
107
+ return result["text"], system_status, "CNN Classifier"
108
+ else:
109
+ # 3. TrOCR Pipeline
110
+ image_cleaned = preprocess_for_trocr(image)
111
+ pixel_values = processor(image_cleaned, return_tensors="pt").pixel_values.to(device)
112
+ with torch.no_grad():
113
+ generated_ids = model.base_model.generate(
114
+ pixel_values=pixel_values,
115
+ num_beams=4,
116
+ length_penalty=1.0,
117
+ max_new_tokens=64,
118
+ early_stopping=True,
119
+ decoder_start_token_id=model.config.decoder_start_token_id
120
+ )
121
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
122
+ return text, system_status, "TrOCR + LoRA"
123
+
124
  except Exception as e:
125
  import traceback
126
  print(traceback.format_exc())
127
+ return f"Error: {str(e)}", "System Failure", "Error"
128
 
129
+ # --- INTERFACE ---
130
+ with gr.Blocks(theme=gr.themes.Soft(), title="DevGen Smart OCR") as demo:
131
+ gr.Markdown("# 🕉️ DevGen Smart Devanagari OCR")
132
+ gr.Markdown("Automatic detection and recognition for both single characters and full words.")
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ with gr.Row():
135
+ with gr.Column(scale=1):
136
+ input_img = gr.Image(type="pil", label="Upload Handwritten Input")
137
+ submit_btn = gr.Button("Recognize", variant="primary")
 
 
 
 
 
138
 
139
+ with gr.Column(scale=1):
140
+ output_text = gr.Textbox(label="Recognized Text", placeholder="Result will appear here...", interactive=False)
141
+ status_text = gr.Label(label="Engine Status")
142
+ model_used = gr.Textbox(label="Model Used", interactive=False)
 
 
 
 
143
 
144
+ submit_btn.click(
145
+ fn=smart_predict,
146
+ inputs=input_img,
147
+ outputs=[output_text, status_text, model_used]
148
+ )
149
+
150
+ gr.Examples(
151
+ examples=[], # You can add local test images here
152
+ inputs=input_img
153
+ )
154
 
155
  if __name__ == "__main__":
156
+ demo.launch()