# manga1 / app.py
# Author: sayed555
# Last commit: "Update app.py" (2a65b68, verified)
import os
import sys
import json
import numpy as np
import cv2
import traceback
import shutil
import gradio as gr
# Model checkpoints may be uploaded to the repository root, while the loaders
# (presumably in the `modules` package - confirm) look for them under
# models/<kind>/. Map each root-level filename to its expected location.
models_mapping = {
    'comictextdetector.pt': 'models/textdetector/comictextdetector.pt',
    'mit48pxctc_ocr.ckpt': 'models/ocr/mit48px_ctc.ckpt',
    'lama_large_512px.ckpt': 'models/inpainting/lama_large_512px.ckpt'
}

# Create every target directory up front. The directories are derived from the
# mapping so the two lists can never drift apart.
for _target_path in models_mapping.values():
    os.makedirs(os.path.dirname(_target_path), exist_ok=True)

# Best-effort relocation: move a checkpoint only when it exists at the root and
# has not already been placed. A failed move is reported (rather than silently
# ignored) but does not stop the app from starting.
for src, dst in models_mapping.items():
    if os.path.exists(src) and not os.path.exists(dst):
        try:
            shutil.move(src, dst)
        except OSError as exc:
            print(f"warning: could not move {src!r} to {dst!r}: {exc}", file=sys.stderr)

# Make the sibling packages imported lazily below (`modules`, `utils`)
# importable regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Lazily-initialised model singletons, created on first use by the detection
# and extraction handlers respectively.
detector_model = None
ocr_model = None
# ==========================================
# Detection handler (returns polygons)
# ==========================================
def detect_boxes(image):
    """Detect text regions in *image* and return them as a JSON string.

    Args:
        image: Page image as a NumPy array from the ``gr.Image(type="numpy")``
            input, or ``None`` when nothing was uploaded. Expected to be RGB
            (3 channels); RGBA (4 channels) and grayscale (2-D) arrays are
            converted too - TODO confirm which modes callers actually send.

    Returns:
        A JSON-encoded list. On success each entry is
        ``{"id", "x", "y", "w", "h", "polygon"}``: x/y/w/h is the axis-aligned
        bounding rectangle and ``polygon`` is the region outline sent to the
        web page. On failure the list holds a single ``{"error": message}``
        entry. Exceptions are never raised to the caller.
    """
    global detector_model
    if image is None:
        return json.dumps([{"error": "لم يتم إرسال أي صورة! ارفع صورة أولاً."}])
    try:
        if detector_model is None:
            from modules.base import load_modules
            load_modules()
            from modules.textdetector import TEXTDETECTORS
            if 'ctd' not in TEXTDETECTORS.module_dict:
                return json.dumps([{"error": "موديل التحديد ctd لم يتم تحميله. تأكد من إضافة setuptools في requirements"}])
            detector_model = TEXTDETECTORS.module_dict['ctd']()
        # Query the loaded state only when the model exposes the method we call.
        if hasattr(detector_model, 'all_model_loaded') and not detector_model.all_model_loaded():
            detector_model.load_model()

        # The detector is fed BGR (OpenCV channel order). Gradio supplies RGB,
        # so convert. Grayscale and RGBA are normalised to 3-channel BGR too,
        # instead of crashing on shape[2] or passing RGB-ordered channels.
        if image.ndim == 2:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[2] == 4:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        elif image.shape[2] == 3:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        else:
            image_bgr = image

        _mask, boxes = detector_model.detect(image_bgr)
        real_boxes = []
        for i, b in enumerate(boxes):
            box = getattr(b, 'box', None)
            if box is not None and len(box) == 4:
                # box is [x1, y1, x2, y2]; expand it to a 4-point rectangle.
                x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
                pts = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
            elif hasattr(b, 'pts') or hasattr(b, 'lines'):
                pts = np.array(b.pts if hasattr(b, 'pts') else b.lines, dtype=np.float32)
            else:
                continue
            if pts.size == 0:
                continue
            # `lines` presumably holds one quad per text line, i.e. an
            # (n, 4, 2) array; flatten to an (N, 2) point set so boundingRect
            # accepts it for any n.
            x, y, w, h = cv2.boundingRect(pts.reshape(-1, 2))
            real_boxes.append({
                "id": i,
                "x": int(x), "y": int(y), "w": int(w), "h": int(h),
                "polygon": pts.tolist(),  # outline sent to the web page
            })
        return json.dumps(real_boxes)
    except Exception:
        return json.dumps([{"error": f"===== خطأ في التحديد =====\n\n{traceback.format_exc()}"}])
# ==========================================
# Extraction handler (accepts polygons)
# ==========================================
def extract_text(image, boxes_json):
    """Run OCR on the regions described by *boxes_json* and return the text.

    Args:
        image: Page image as a NumPy array from the ``gr.Image(type="numpy")``
            input, or ``None`` when nothing was uploaded. Expected to be RGB;
            RGBA and grayscale arrays are converted too - TODO confirm.
        boxes_json: JSON string holding a list of region dicts in the format
            emitted by the detection endpoint. Each entry carries either a
            ``"polygon"`` (list of [x, y] points; one extra nesting level is
            tolerated) or ``"x"``, ``"y"``, ``"w"``, ``"h"``. A single dict is
            accepted as a one-item list. If the first entry has an ``"error"``
            key its message is returned as-is; later ``"error"`` entries are
            skipped.

    Returns:
        The text of every OCR result, joined by blank lines, or a user-facing
        message / formatted traceback on failure. Exceptions are never raised
        to the caller.
    """
    global ocr_model
    try:
        if not boxes_json or str(boxes_json).strip() == "":
            return "لم يتم إرسال أي مربعات! ارسم أو حدد المربعات أولاً."
        boxes_data = json.loads(boxes_json)
        if isinstance(boxes_data, dict):
            boxes_data = [boxes_data]
        if not boxes_data:
            return "لم يتم العثور على مربعات."
        # Surface an error produced by the detection step directly to the user.
        if 'error' in boxes_data[0]:
            return boxes_data[0]['error']
        if image is None:
            return "لم يتم إرسال أي صورة! ارفع صورة أولاً."

        from utils.textblock import TextBlock
        blk_list = []
        for b in boxes_data:
            if 'error' in b:
                continue
            if b.get('polygon'):
                # Polygon drawn/edited on the web page. Normalise to an (N, 2)
                # point list, since multi-line regions may arrive nested deeper.
                pts = np.array(b['polygon'], dtype=np.float32).reshape(-1, 2)
                x, y, w, h = cv2.boundingRect(pts)
            else:
                x, y, w, h = int(b['x']), int(b['y']), int(b['w']), int(b['h'])
                pts = np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h]], dtype=np.float32)
            blk = TextBlock()
            blk.box = [x, y, x + w, y + h]
            blk.polygon = pts
            blk.lines = [pts]
            blk_list.append(blk)
        if not blk_list:
            return "لم يتم العثور على مربعات."

        if ocr_model is None:
            from modules.base import load_modules
            load_modules()
            from modules.ocr import OCR
            keys = list(OCR.module_dict)
            if not keys:
                return "لم يتم تحميل أي موديل OCR."
            # Prefer the 48px CTC recogniser; otherwise use the first registered one.
            target_key = next(
                (k for k in keys if 'ctc' in k.lower() or '48' in k.lower()),
                keys[0],
            )
            ocr_model = OCR.module_dict[target_key]()
        if hasattr(ocr_model, 'all_model_loaded') and not ocr_model.all_model_loaded():
            ocr_model.load_model()

        # OCR is fed BGR (OpenCV channel order). Gradio supplies RGB, so
        # convert; grayscale and RGBA are normalised to 3-channel BGR as well.
        if image.ndim == 2:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[2] == 4:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        elif image.shape[2] == 3:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        else:
            image_bgr = image

        results = ocr_model.run_ocr(image_bgr, blk_list)
        # A result's text may be a list of line strings; join those with spaces.
        final_text = [
            " ".join(r.text) if isinstance(r.text, list) else str(r.text)
            for r in (results or [])
        ]
        res_text = "\n\n".join(final_text)
        if not res_text.strip():
            return "الموديل لم يستطع قراءة أي نص. (تأكد أن المضلع محيط بالنص بشكل صحيح)"
        return res_text
    except Exception:
        return f"===== خطأ برمجي =====\n\n{traceback.format_exc()}"
# ==========================================
# Gradio UI / API wiring
# ==========================================
with gr.Blocks() as demo:
    # Components are laid out in creation order: page image, box JSON,
    # the two result boxes, then the action buttons.
    img_input = gr.Image(type="numpy")
    boxes_input = gr.Textbox()
    detect_out = gr.Textbox()
    extract_out = gr.Textbox()
    detect_btn = gr.Button("detect")
    extract_btn = gr.Button("extract")

    # Each button is also published as a named API endpoint (api_name), so
    # an external client can call detection and extraction directly.
    detect_btn.click(
        detect_boxes,
        inputs=[img_input],
        outputs=[detect_out],
        api_name="detect",
    )
    extract_btn.click(
        extract_text,
        inputs=[img_input, boxes_input],
        outputs=[extract_out],
        api_name="extract",
    )

if __name__ == "__main__":
    demo.launch()