# Ehtesham123's picture
# Update app.py
# b4cebb1 verified
import gradio as gr
from STD_detect import OBBPredictor
from STR_recognize import TextRecognizer
import os
# Initialize models
# Checkpoint locations, given relative to the directory the app is started from.
STD_model_path = "pretrained_model/STD.pt"    # loaded by the detector below
STR_ckpt_path = "pretrained_model/STR.ckpt"   # loaded by the recognizer below

# Both models are constructed once at import time; run_pipeline() reuses these
# module-level instances on every call.
detector = OBBPredictor(STD_model_path)
recognizer = TextRecognizer(STR_ckpt_path, device="cpu")  # or "cuda" if on GPU
# ==== OCR pipeline function ====
def run_pipeline(image):
    """Detect text regions in ``image`` and recognize the text in each one.

    The image is passed to the module-level ``detector``; every crop it
    yields is then passed to the module-level ``recognizer``.

    Args:
        image: The image supplied by the Gradio input component, or ``None``
            when the form was submitted without an image.

    Returns:
        str: The recognized strings as a 1-based numbered list, one entry per
        line (``"1. <text>"``). ``"No image provided."`` is returned when
        ``image`` is ``None`` and ``"No text detected."`` when the detector
        produced no crops.
    """
    # Gradio hands the callback None for an empty image input; stop here
    # instead of forwarding None to the detector.
    if image is None:
        return "No image provided."

    crops = detector.predict(image)
    recognized_texts = [recognizer.recognize(crop) for crop in crops]
    if not recognized_texts:
        return "No text detected."

    # Number the results starting at 1 for display.
    return "\n".join(
        f"{index}. {text}" for index, text in enumerate(recognized_texts, start=1)
    )
# ==== Get sample image paths ====
def list_sample_images(directory="samples"):
    """Return the paths of the image files found directly inside ``directory``.

    Args:
        directory: Folder to scan. Defaults to ``"samples"``.

    Returns:
        list[str]: ``"<directory>/<filename>"`` for every entry whose name
        ends in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive), sorted by
        filename. Empty when the folder does not exist or is not a directory.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    try:
        # os.listdir() returns entries in arbitrary order; sort so the example
        # gallery is stable between runs and machines.
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        # A missing samples folder should not stop the app from starting.
        return []
    return [
        f"{directory}/{name}"
        for name in names
        if name.lower().endswith(image_suffixes)
    ]


example_images = list_sample_images()
# ==== Gradio app with ONLY sample images ====
# Single-function UI: one image input, one text output, with the sample images
# offered as clickable examples.
demo = gr.Interface(
    title="Two-Stage OCR Network for Aero Engine Blades Serial Number",
    description="Choose only predefined AEB images. The model will detect Serial text regions and recognize their contents.",
    fn=run_pipeline,
    inputs=gr.Image(type="pil", label="Choose a sample image"),
    outputs=gr.Textbox(label="Recognized Serial Text"),
    # Each example row is a list with one value per input component.
    examples=[[sample_path] for sample_path in example_images],
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported.
    demo.launch()