File size: 6,845 Bytes
4fe11e0
dc17282
e74b0f8
 
adb49d0
6dc6d66
e74b0f8
 
dc17282
ae3fa31
dc17282
e74b0f8
dc17282
e74b0f8
adb49d0
 
e74b0f8
 
6cd700e
68beaa2
dc17282
e74b0f8
9487c54
 
 
 
 
 
 
 
dc17282
9487c54
6cd700e
adb49d0
ae3fa31
6cd700e
e8bc8af
 
 
 
 
 
6cd700e
e8bc8af
 
 
6cd700e
 
e8bc8af
 
 
 
6cd700e
e8bc8af
 
dc17282
6cd700e
ae3fa31
 
dc17282
42d5462
6cd700e
 
42d5462
6cd700e
dc17282
 
 
 
 
 
9ebb598
dc17282
 
 
42d5462
 
6cd700e
 
 
68beaa2
 
6cd700e
 
42d5462
 
9ebb598
42d5462
ced8950
6cd700e
9ebb598
6cd700e
68beaa2
9ebb598
68beaa2
9487c54
 
6cd700e
 
9487c54
6cd700e
9487c54
6cd700e
 
 
 
9487c54
6cd700e
b53bf6e
68beaa2
6dc6d66
68beaa2
 
 
 
6dc6d66
 
68beaa2
9ebb598
 
 
 
68beaa2
 
6cd700e
9ebb598
e8bc8af
68beaa2
 
 
6cd700e
e8bc8af
68beaa2
e8bc8af
ced8950
e74b0f8
 
6cd700e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import io
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from peft import PeftModel
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from cnn_model import CharacterClassifier
from preprocessing import preprocess_for_ocr

# --- CONFIGURATION ---
# Checkpoint id handed to TrOCRProcessor / VisionEncoderDecoderModel.from_pretrained.
BASE_MODEL_ID = "paudelanil/trocr-devanagari-2"
# LoRA adapter id handed to PeftModel.from_pretrained on top of the base model.
ADAPTER_ID = "manishw10/devgen-trocr-devanagari-lora"
# Weights file handed to CharacterClassifier as model_path.
CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- ENGINE INITIALIZATION ---
# Runs at import time: loads the TrOCR processor and base model, applies the
# LoRA adapter, moves the model to `device`, and builds the CNN classifier.
print("System: Initializing Stable Premium Engine (3.50.2)...")
processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
# Take the decoder's special token ids from the tokenizer: decoding starts at
# CLS, pads with PAD and treats SEP as end-of-sequence.
base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
base_model.config.pad_token_id = processor.tokenizer.pad_token_id
base_model.config.eos_token_id = processor.tokenizer.sep_token_id
# Mirror the decoder's vocabulary size onto the top-level config.
base_model.config.vocab_size = base_model.config.decoder.vocab_size

peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
try:
    # Fold the LoRA weights into the base model so inference runs on a plain model.
    model = peft_model.merge_and_unload()
except Exception as exc:
    # Best-effort: keep serving with the unmerged PeftModel, but say why the
    # merge failed instead of swallowing the error silently.
    print(f"System: LoRA merge failed ({exc!r}); using the unmerged adapter model.")
    model = peft_model
model.to(device)
model.eval()
cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)

# --- ORIGINAL ROUTING LOGIC ---
def _flood_fill(binary, visited, start_y, start_x, h, w):
    """Mark one 4-connected foreground region as visited and return its size.

    Iterative depth-first fill from (start_y, start_x). Candidate cells
    outside the h x w window, already visited, or with a falsy value in
    ``binary`` are skipped. ``visited`` is updated in place.

    Returns:
        Number of cells newly marked (0 when the start cell is not fillable).
    """
    pending = [(start_y, start_x)]
    filled = 0
    while pending:
        cy, cx = pending.pop()
        inside = 0 <= cy < h and 0 <= cx < w
        if not inside or visited[cy, cx] or not binary[cy, cx]:
            continue
        visited[cy, cx] = True
        filled += 1
        # Only orthogonal neighbours: diagonal contact does not join regions.
        pending += [(cy + 1, cx), (cy - 1, cx), (cy, cx + 1), (cy, cx - 1)]
    return filled

def count_blobs(binary):
    """Count the sizeable 4-connected foreground components in ``binary``.

    Every nonzero pixel is foreground. Components smaller than
    ``max(0.1% of the pixel count, 10)`` pixels are treated as noise and are
    not counted.

    Args:
        binary: 2-D array whose nonzero entries mark foreground pixels.

    Returns:
        Number of components whose area reaches the size threshold.
    """
    if binary.size == 0:
        return 0
    min_size = max(binary.size * 0.001, 10)
    # Label components in native code rather than visiting every pixel from
    # Python. connectivity=4 joins only up/down/left/right neighbours.
    mask = (np.asarray(binary) != 0).astype(np.uint8)
    _, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=4)
    # Label 0 is the background (zero pixels), so skip its row.
    areas = stats[1:, cv2.CC_STAT_AREA]
    return int(np.count_nonzero(areas >= min_size))

def original_classify_input(image):
    """Heuristically route ``image`` to the character or word recognizer.

    The grayscale image is binarized with a mean-relative threshold. The tight
    bounding box of the dark ("ink") pixels gives an aspect ratio, and
    ``count_blobs`` gives the number of sizeable ink components. An ordered
    rule list then decides the route.

    Args:
        image: PIL-style image; ``convert("L")`` is applied here.

    Returns:
        ``(mode, aspect_ratio, blob_count)`` where ``mode`` is ``"character"``
        or ``"word"``. An image with no ink yields ``("character", 1.0, 1)``.
    """
    pixels = np.array(image.convert("L"))
    # Ink = pixels darker than 75% of the mean brightness, with the cutoff capped at 200.
    cutoff = min(pixels.mean() * 0.75, 200)
    ink = (pixels < cutoff).astype(np.uint8)

    ink_rows = np.flatnonzero(np.any(ink, axis=1))
    ink_cols = np.flatnonzero(np.any(ink, axis=0))
    if ink_rows.size == 0 or ink_cols.size == 0:
        return "character", 1.0, 1

    box_height = ink_rows[-1] - ink_rows[0] + 1
    box_width = ink_cols[-1] - ink_cols[0] + 1
    aspect = box_width / box_height
    blobs = count_blobs(ink)

    # Ordered (condition, is_character) rules: the first true condition wins;
    # when none fires the input is treated as a character.
    rules = (
        (aspect > 2.5, False),
        (aspect > 1.8 and blobs >= 3, False),
        (blobs >= 4, False),
        (aspect < 1.3 and blobs <= 2, True),
        (blobs == 1 and aspect < 1.5, True),
        (aspect < 1.75 and blobs <= 2, True),
        (aspect > 1.6, False),
    )
    is_character = next((verdict for fired, verdict in rules if fired), True)
    return ("character" if is_character else "word"), aspect, blobs

def get_confidence_html(confidence):
    """Render ``confidence`` as an SVG progress ring with a percentage label.

    Args:
        confidence: Score expected in ``[0, 1]``. Values outside that range
            are clamped, and NaN is shown as 0%, so the ring and label always
            stay well-formed.

    Returns:
        An HTML string. The ring is green above 0.9, amber above 0.7 and red
        otherwise. The label is the truncated integer percentage.
    """
    value = float(confidence)
    # min/max propagate NaN and int(NaN) raises, so map NaN to 0 first.
    if np.isnan(value):
        value = 0.0
    value = min(max(value, 0.0), 1.0)
    color = "#10b981" if value > 0.9 else "#f59e0b" if value > 0.7 else "#ef4444"
    # 282.7 ~= 2*pi*45, the circumference of the r=45 circle. The dash offset
    # hides the unfilled fraction of the ring.
    return f"""<div style="display: flex; flex-direction: column; align-items: center; background: rgba(0,0,0,0.2); border-radius: 20px; padding: 15px;">
        <svg width="100" height="100" viewBox="0 0 100 100">
            <circle cx="50" cy="50" r="45" fill="none" stroke="rgba(255,255,255,0.1)" stroke-width="8" />
            <circle cx="50" cy="50" r="45" fill="none" stroke="{color}" stroke-width="8" stroke-dasharray="282.7" stroke-dashoffset="{282.7 * (1 - value)}" stroke-linecap="round" />
            <text x="50" y="55" font-family="Arial" font-size="20" font-weight="bold" fill="{color}" text-anchor="middle">{int(value * 100)}%</text>
        </svg>
    </div>"""

# --- PREDICT ---
def predict(image, manual_mode):
    """Recognize text in ``image`` and build the five UI outputs.

    Args:
        image: Uploaded image (saved via ``image.save``), or None when nothing
            was uploaded.
        manual_mode: ``"Automatic"`` routes via ``original_classify_input``.
            Any other value is lower-cased and used as the mode directly
            (``"character"`` selects the CNN when it is available).

    Returns:
        ``(preprocessed_image, text, status, engine_name, confidence_html)``.
        On any failure the text is ``"Error: ..."``, the status ``"Failed"``
        and the engine ``"None"``.
    """
    if image is None:
        return None, None, "Upload image.", "", ""
    # Returned from the error path, so it must exist even if preprocessing fails.
    pre_pil = None
    try:
        # The preprocessor is given encoded PNG bytes. Presumably it decodes
        # them and returns a PIL image -- TODO confirm against preprocessing.py.
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        pre_pil = preprocess_for_ocr(buf.getvalue())

        if manual_mode == "Automatic":
            mode, ar, bc = original_classify_input(pre_pil)
            status = f"System: {mode.upper()} (AR: {ar:.2f}, Blobs: {bc})"
        else:
            mode = manual_mode.lower()
            status = f"Manual Mode: {mode.upper()}"

        # Characters go to the CNN when it is available; everything else
        # (including characters without a CNN) goes to TrOCR.
        if mode == "character" and cnn_engine.available:
            res = cnn_engine.predict(pre_pil)
            return pre_pil, res["text"], status, "CNN Classifier", get_confidence_html(res["confidence"])

        pixel_values = processor(pre_pil, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            out = model.generate(
                pixel_values,
                num_beams=4,
                max_length=128,
                early_stopping=True,
                return_dict_in_generate=True,
                output_scores=True,
                decoder_start_token_id=model.config.decoder_start_token_id,
            )
            # Under beam search, out.scores has one row per beam. beam_indices
            # maps each generated token back to the beam that produced it.
            # Without it, row 0 would read every step from beam 0.
            transition = model.compute_transition_scores(
                out.sequences,
                out.scores,
                beam_indices=getattr(out, "beam_indices", None),
                normalize_logits=True,
            )
        token_probs = torch.exp(transition[0])
        txt = processor.batch_decode(out.sequences, skip_special_tokens=True)[0]
        # Mean per-token probability; an empty score row yields 0 instead of NaN.
        confidence = float(token_probs.mean().item()) if token_probs.numel() else 0.0
        return pre_pil, txt, status, "TrOCR + LoRA", get_confidence_html(confidence)
    except Exception as e:
        return pre_pil, f"Error: {str(e)}", "Failed", "None", ""

# --- PREMIUM CSS (Gradio 3.x Optimized) ---
# Stylesheet passed to gr.Blocks(css=...). ".premium-card" and ".result-box"
# target the elem_classes set on the main column and on the result textbox.
# The "!important" flags are presumably there to win over Gradio's own theme
# rules -- verify if the Gradio version changes.
CSS = """
.gradio-container { background: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%) !important; color: white !important; }
.premium-card { background: rgba(30, 41, 59, 0.7) !important; border: 1px solid rgba(255,255,255,0.1); border-radius: 20px; padding: 20px; box-shadow: 0 10px 30px rgba(0,0,0,0.5); }
.result-box textarea { font-size: 2.5rem !important; font-weight: bold !important; color: #818cf8 !important; text-align: center !important; background: transparent !important; border: none !important; }
h1 { color: #818cf8 !important; font-size: 2.5rem !important; }
"""

# UI layout: inputs (image, mode radio, button) on the left, results
# (confidence ring, text, status, engine name) on the right, and the
# preprocessed image underneath. Components render in creation order.
with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_classes="premium-card"):
        gr.Markdown("# 🕉️ DevGen OCR")
        with gr.Row():
            # Input column.
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Input")
                mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Mode")
                sub_btn = gr.Button("Recognize", variant="primary")
            # Output column.
            with gr.Column(scale=1):
                conf_html = gr.HTML()
                text_out = gr.Textbox(label="Result", elem_classes="result-box", interactive=False)
                status_md = gr.Markdown("Ready.")
                engine_txt = gr.Textbox(label="Model", interactive=False)
        # Shows the image after preprocess_for_ocr, i.e. what the engines receive.
        with gr.Column():
            gr.Markdown("### 🛠️ Visual Debug: What the Model Sees")
            img_proc = gr.Image(type="pil", label="Preprocessed", interactive=False)

    # The output list order must match the 5-tuple returned by predict().
    sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt, conf_html])

# When run as a script, serve on all network interfaces at port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)