Spaces:
Sleeping
Sleeping
File size: 6,845 Bytes
4fe11e0 dc17282 e74b0f8 adb49d0 6dc6d66 e74b0f8 dc17282 ae3fa31 dc17282 e74b0f8 dc17282 e74b0f8 adb49d0 e74b0f8 6cd700e 68beaa2 dc17282 e74b0f8 9487c54 dc17282 9487c54 6cd700e adb49d0 ae3fa31 6cd700e e8bc8af 6cd700e e8bc8af 6cd700e e8bc8af 6cd700e e8bc8af dc17282 6cd700e ae3fa31 dc17282 42d5462 6cd700e 42d5462 6cd700e dc17282 9ebb598 dc17282 42d5462 6cd700e 68beaa2 6cd700e 42d5462 9ebb598 42d5462 ced8950 6cd700e 9ebb598 6cd700e 68beaa2 9ebb598 68beaa2 9487c54 6cd700e 9487c54 6cd700e 9487c54 6cd700e 9487c54 6cd700e b53bf6e 68beaa2 6dc6d66 68beaa2 6dc6d66 68beaa2 9ebb598 68beaa2 6cd700e 9ebb598 e8bc8af 68beaa2 6cd700e e8bc8af 68beaa2 e8bc8af ced8950 e74b0f8 6cd700e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | import os
import io
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from peft import PeftModel
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from cnn_model import CharacterClassifier
from preprocessing import preprocess_for_ocr
# --- CONFIGURATION ---
# Model id (or path) of the base TrOCR checkpoint; passed to from_pretrained
# for both the processor and the VisionEncoderDecoderModel below.
BASE_MODEL_ID = "paudelanil/trocr-devanagari-2"
# Model id of the LoRA adapter loaded on top of the base model via PeftModel.
ADAPTER_ID = "manishw10/devgen-trocr-devanagari-lora"
# File handed to CharacterClassifier(model_path=...); presumably the CNN's
# saved weights — confirm against cnn_model.CharacterClassifier.
CNN_MODEL_PATH = "devanagari-cnn-classifier.pt"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- ENGINE INITIALIZATION ---
# Runs once at import time: loads the TrOCR processor and base model, applies
# the LoRA adapter, and builds the CNN character classifier. Every request
# then reuses these module-level objects.
print("System: Initializing Stable Premium Engine (3.50.2)...")
processor = TrOCRProcessor.from_pretrained(BASE_MODEL_ID)
base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL_ID)
# Copy the tokenizer's special-token ids into the model config so generate()
# starts decoding at CLS, pads with PAD and stops at SEP.
base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
base_model.config.pad_token_id = processor.tokenizer.pad_token_id
base_model.config.eos_token_id = processor.tokenizer.sep_token_id
# Mirror the decoder's vocab size at the top level of the config; presumably
# for code that reads model.config.vocab_size (e.g. score post-processing).
base_model.config.vocab_size = base_model.config.decoder.vocab_size
peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
try:
    # Fold the LoRA layers into the base weights and drop the PEFT wrapper.
    model = peft_model.merge_and_unload()
except Exception as merge_err:
    # Best-effort: keep serving through the un-merged adapter, but say so
    # instead of failing silently.
    print(f"System: LoRA merge failed ({merge_err!r}); using unmerged adapter.")
    model = peft_model
model.to(device)
model.eval()  # inference only: disables dropout and similar training-time behavior
cnn_engine = CharacterClassifier(model_path=CNN_MODEL_PATH, device=device)
# --- ORIGINAL ROUTING LOGIC ---
def _flood_fill(binary, visited, start_y, start_x, h, w):
stack = [(start_y, start_x)]
size = 0
while stack:
y, x = stack.pop()
if y<0 or y>=h or x<0 or x>=w or visited[y,x] or not binary[y,x]: continue
visited[y,x] = True; size += 1
stack.extend([(y+1,x),(y-1,x),(y,x+1),(y,x-1)])
return size
def count_blobs(binary):
    """Count 4-connected foreground components that meet a minimum size.

    Args:
        binary: 2-D array; truthy entries are foreground ("ink").

    Returns:
        Number of connected components (4-connectivity, via ``_flood_fill``)
        with at least ``max(0.1% of the pixel count, 10)`` pixels. Smaller
        components are not counted (presumably speckle noise).
    """
    h, w = binary.shape
    visited = np.zeros_like(binary, dtype=bool)
    # Loop-invariant: compute the size cutoff once rather than per component.
    min_size = max(binary.size * 0.001, 10)
    count = 0
    # Only foreground pixels can seed a component, so iterate just those.
    # They come in the same row-major order a full scan would use.
    # .tolist() yields plain Python ints, which keeps the fill's per-pixel
    # index arithmetic fast.
    for y, x in np.argwhere(binary).tolist():
        if visited[y, x]:
            continue
        if _flood_fill(binary, visited, y, x, h, w) >= min_size:
            count += 1
    return count
def original_classify_input(image):
    """Heuristically decide whether ``image`` shows one character or a word.

    The image is converted to grayscale. Pixels darker than 75% of the mean
    brightness (with the cutoff capped at 200) count as ink. The ink's
    bounding-box aspect ratio (width / height) and its blob count from
    ``count_blobs`` are then run through an ordered list of rules; the first
    rule that matches decides.

    Args:
        image: PIL image (anything with ``.convert("L")`` that NumPy can
            turn into a 2-D array).

    Returns:
        ``(label, aspect_ratio, blob_count)`` where ``label`` is
        ``"character"`` or ``"word"``. An image with no ink returns
        ``("character", 1.0, 1)``.
    """
    pixels = np.array(image.convert("L"))
    cutoff = min(pixels.mean() * 0.75, 200)
    ink = (pixels < cutoff).astype(np.uint8)

    ink_rows = np.flatnonzero(ink.any(axis=1))
    ink_cols = np.flatnonzero(ink.any(axis=0))
    if ink_rows.size == 0 or ink_cols.size == 0:
        return "character", 1.0, 1

    width = ink_cols[-1] - ink_cols[0] + 1
    height = ink_rows[-1] - ink_rows[0] + 1
    aspect = width / height
    blobs = count_blobs(ink)

    # (condition, is_character) pairs; the order matters because the first
    # hit wins. When nothing matches, the input is treated as a character.
    rules = (
        (aspect > 2.5, False),
        (aspect > 1.8 and blobs >= 3, False),
        (blobs >= 4, False),
        (aspect < 1.3 and blobs <= 2, True),
        (blobs == 1 and aspect < 1.5, True),
        (aspect < 1.75 and blobs <= 2, True),
        (aspect > 1.6, False),
    )
    is_char = next((verdict for hit, verdict in rules if hit), True)
    label = "character" if is_char else "word"
    return label, aspect, blobs
def get_confidence_html(confidence):
    """Render a confidence score as an SVG ring gauge with a percentage label.

    Args:
        confidence: Score expected in [0, 1]; anything ``float()`` accepts.
            Values outside the range are clamped, and NaN is shown as 0%.
            A bad score therefore can neither raise (``int(nan)``) nor draw
            a ring or label beyond 100%.

    Returns:
        HTML string. The ring and text are green above 90%, amber above 70%
        and red otherwise. The drawn arc is proportional to the clamped
        confidence, and the label is the truncated integer percentage.
    """
    import math

    confidence = float(confidence)
    if math.isnan(confidence):
        confidence = 0.0
    confidence = min(max(confidence, 0.0), 1.0)
    color = "#10b981" if confidence > 0.9 else "#f59e0b" if confidence > 0.7 else "#ef4444"
    # 282.7 ~= 2*pi*45, the circumference of the r=45 ring. Offsetting the
    # dash by the missing fraction leaves an arc proportional to confidence.
    return f"""<div style="display: flex; flex-direction: column; align-items: center; background: rgba(0,0,0,0.2); border-radius: 20px; padding: 15px;">
<svg width="100" height="100" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="rgba(255,255,255,0.1)" stroke-width="8" />
<circle cx="50" cy="50" r="45" fill="none" stroke="{color}" stroke-width="8" stroke-dasharray="282.7" stroke-dashoffset="{282.7 * (1 - confidence)}" stroke-linecap="round" />
<text x="50" y="55" font-family="Arial" font-size="20" font-weight="bold" fill="{color}" text-anchor="middle">{int(confidence * 100)}%</text>
</svg>
</div>"""
# --- PREDICT ---
def _trocr_recognize(pil_image):
    """Recognize text with TrOCR + LoRA using 4-beam search.

    Returns:
        ``(text, confidence)``. ``confidence`` is the mean per-token
        probability of the returned (best) beam, or 0.0 if no token scores
        were produced.
    """
    pixel_values = processor(pil_image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        out = model.generate(
            pixel_values,
            num_beams=4,
            max_length=128,
            early_stopping=True,
            return_dict_in_generate=True,
            output_scores=True,
            decoder_start_token_id=model.config.decoder_start_token_id,
        )
        # With beam search, out.scores holds one row per beam at every step.
        # beam_indices tells compute_transition_scores which row belongs to
        # the beam that became out.sequences[0]; without it, beam 0's row is
        # read at every step.
        token_log_probs = model.compute_transition_scores(
            out.sequences,
            out.scores,
            beam_indices=getattr(out, "beam_indices", None),
            normalize_logits=True,
        )[0]
    text = processor.batch_decode(out.sequences, skip_special_tokens=True)[0]
    if token_log_probs.numel() == 0:
        return text, 0.0
    return text, float(torch.exp(token_log_probs).mean().item())


def predict(image, manual_mode):
    """Gradio callback: preprocess, route and recognize one uploaded image.

    Args:
        image: PIL image from the input component, or None if nothing was
            uploaded.
        manual_mode: "Automatic", "Word" or "Character" from the mode radio.
            "Automatic" routes with ``original_classify_input``.

    Returns:
        A 5-tuple for (img_proc, text_out, status_md, engine_txt, conf_html):
        the preprocessed image, the recognized text (or "Error: ..."), a
        status line, the engine name and the confidence-gauge HTML.
        Characters go to the CNN when ``cnn_engine.available``; everything
        else goes to TrOCR.
    """
    if image is None:
        return None, None, "Upload image.", "", ""
    pre_pil = None
    try:
        # Preprocessing and routing are inside the try so their failures
        # also reach the UI error path instead of escaping the callback.
        with io.BytesIO() as buf:
            image.save(buf, format="PNG")
            pre_pil = preprocess_for_ocr(buf.getvalue())
        if manual_mode == "Automatic":
            mode, ar, bc = original_classify_input(pre_pil)
            status = f"System: {mode.upper()} (AR: {ar:.2f}, Blobs: {bc})"
        else:
            mode = manual_mode.lower()
            status = f"Manual Mode: {mode.upper()}"
        if mode == "character" and cnn_engine.available:
            res = cnn_engine.predict(pre_pil)
            return pre_pil, res["text"], status, "CNN Classifier", get_confidence_html(res["confidence"])
        txt, confidence = _trocr_recognize(pre_pil)
        return pre_pil, txt, status, "TrOCR + LoRA", get_confidence_html(confidence)
    except Exception as e:
        # UI boundary: show the error in the app, but keep the traceback in
        # the server log so failures are diagnosable.
        import traceback

        traceback.print_exc()
        return pre_pil, f"Error: {str(e)}", "Failed", "None", ""
# --- PREMIUM CSS (Gradio 3.x Optimized) ---
# Stylesheet passed to gr.Blocks(css=...). It sets a dark gradient page
# background and a translucent rounded "premium-card" panel. It also makes the
# result textbox (".result-box") large, bold and centered, with indigo
# h1 headings.
CSS = """
.gradio-container { background: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%) !important; color: white !important; }
.premium-card { background: rgba(30, 41, 59, 0.7) !important; border: 1px solid rgba(255,255,255,0.1); border-radius: 20px; padding: 20px; box-shadow: 0 10px 30px rgba(0,0,0,0.5); }
.result-box textarea { font-size: 2.5rem !important; font-weight: bold !important; color: #818cf8 !important; text-align: center !important; background: transparent !important; border: none !important; }
h1 { color: #818cf8 !important; font-size: 2.5rem !important; }
"""
# UI layout: one styled card with the inputs and results side by side, plus a
# debug view of the preprocessed image. Built with CSS defined above.
with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_classes="premium-card"):
        gr.Markdown("# 🕉️ DevGen OCR")
        with gr.Row():
            # Left: image upload, routing mode and the trigger button.
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Input")
                mode_ctrl = gr.Radio(["Automatic", "Word", "Character"], value="Automatic", label="Mode")
                sub_btn = gr.Button("Recognize", variant="primary")
            # Right: confidence gauge, recognized text, status line, engine name.
            with gr.Column(scale=1):
                conf_html = gr.HTML()
                text_out = gr.Textbox(label="Result", elem_classes="result-box", interactive=False)
                status_md = gr.Markdown("Ready.")
                engine_txt = gr.Textbox(label="Model", interactive=False)
        # Shows the image that predict() actually feeds to the recognizers.
        with gr.Column():
            gr.Markdown("### 🛠️ Visual Debug: What the Model Sees")
            img_proc = gr.Image(type="pil", label="Preprocessed", interactive=False)
    # predict's 5-tuple return value maps onto these outputs in order.
    sub_btn.click(predict, [img_in, mode_ctrl], [img_proc, text_out, status_md, engine_txt, conf_html])
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) at port 7860 when run as a script.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|