"""Gradio demo for a two-stage OCR pipeline on aero engine blade images.

Stage 1 (``OBBPredictor``) produces crops from the input image; stage 2
(``TextRecognizer``) is called on each crop. The recognized strings are shown
as a numbered list in a Gradio text box.

Both models are loaded at import time, and all paths are relative to the
current working directory.
"""

import os

import gradio as gr

from STD_detect import OBBPredictor
from STR_recognize import TextRecognizer

# Initialize models
STD_model_path = "pretrained_model/STD.pt"
STR_ckpt_path = "pretrained_model/STR.ckpt"

detector = OBBPredictor(STD_model_path)
recognizer = TextRecognizer(STR_ckpt_path, device='cpu')  # or 'cuda' if on GPU

# Directory scanned for example images, and the file suffixes accepted from it.
SAMPLES_DIR = "samples"
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


# ==== OCR pipeline function ====
def run_pipeline(image) -> str:
    """Run detection then recognition on one image and format the result.

    Args:
        image: The value Gradio passes for the image input. The input is
            declared with ``type="pil"``; Gradio passes ``None`` when the
            user submits without selecting an image.

    Returns:
        One line per recognized text, numbered from 1 (``"1. <text>"``),
        joined by newlines; ``"No text detected."`` when the detector yields
        no crops; or a prompt to choose an image when ``image`` is ``None``.
    """
    # Guard against an empty submission instead of crashing in the detector.
    if image is None:
        return "Please choose an image first."

    crops = detector.predict(image)
    recognized_texts = [recognizer.recognize(crop) for crop in crops]
    if not recognized_texts:
        return "No text detected."
    return "\n".join(
        f"{index}. {text}" for index, text in enumerate(recognized_texts, start=1)
    )


def _list_sample_images(directory: str) -> list:
    """Return sorted ``"<directory>/<name>"`` paths of the images in *directory*.

    A missing directory yields an empty list so the app can still start
    (without an examples gallery) rather than failing at import time.
    """
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return []
    # os.listdir returns entries in arbitrary order; sort so the examples
    # gallery is laid out the same way on every run and platform.
    return [
        f"{directory}/{name}"
        for name in sorted(names)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]


# ==== Get sample image paths ====
example_images = _list_sample_images(SAMPLES_DIR)

# ==== Gradio app with ONLY sample images ====
demo = gr.Interface(
    fn=run_pipeline,
    inputs=gr.Image(type="pil", label="Choose a sample image"),
    outputs=gr.Textbox(label="Recognized Serial Text"),
    # List of lists required (one inner list per example row); fall back to
    # None, the parameter's default, when no sample images were found.
    examples=[[img] for img in example_images] or None,
    title="Two-Stage OCR Network for Aero Engine Blades Serial Number",
    description="Choose only predefined AEB images. The model will detect Serial text regions and recognize their contents.",
)

if __name__ == "__main__":
    demo.launch()