# Hugging Face Spaces page residue (Space status header, preserved as comments):
# Spaces:
# Sleeping
# Sleeping
import io
import os
import tempfile

import gradio as gr
import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from hezar.models import Model
from hezar.utils import load_image, draw_boxes
# Load models on CPU (Hugging Face Spaces default)
# CRAFT: scene-text detector used to locate text-region boxes in the image.
craft_model = Model.load("hezarai/CRAFT", device="cpu")
# TrOCR processor + encoder-decoder model: transcribe each cropped text region.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
def _box_to_rect(box):
    """Normalize a detection box to an integer (left, top, right, bottom) tuple.

    Supports the layouts observed from CRAFT-style detectors:
      * [x, y, width, height]            -- four scalars
      * [[x1, y1], [x2, y2]]             -- two corner points
      * [[x1, y1], ..., [x4, y4]]        -- quadrilateral polygon (CRAFT output)

    The original code unpacked any length-4 box as (x, y, w, h), which breaks
    on 4-point polygons (also length 4). We disambiguate by checking whether
    the first element is itself a 2-element point.

    Returns:
        tuple[int, int, int, int] | None: PIL-crop rectangle, or None if the
        box matches none of the supported layouts.
    """
    try:
        if len(box) == 4 and not hasattr(box[0], "__len__"):
            # Scalar layout: [x, y, width, height]
            x, y, width, height = box
            return int(x), int(y), int(x + width), int(y + height)
        points = [p for p in box if hasattr(p, "__len__") and len(p) == 2]
        if len(points) == len(box) and len(points) >= 2:
            # Point layout: axis-aligned bounding box of all vertices.
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            return int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))
    except (TypeError, ValueError):
        pass
    return None


def recognize_handwritten_text(image):
    """Detect handwritten text regions with CRAFT and transcribe them with TrOCR.

    Args:
        image: Uploaded image; a PIL.Image or anything numpy can convert.

    Returns:
        tuple: (annotated PIL image with detection boxes drawn, status string
        containing the recognized text). On any exception the original image
        is returned together with an "Error: ..." message.
    """
    tmp_path = None  # sentinel so `finally` never needs the locals() hack
    try:
        # Normalize input to an RGB PIL image regardless of upload format.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image)).convert("RGB")
        # Persist to a temporary JPEG so hezar's load_image can read a path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            image.save(tmp_file.name, format="JPEG")
            tmp_path = tmp_file.name
        # Load image with hezar utils using the file path.
        processed_image = load_image(tmp_path)
        # hezar may return a non-ndarray type; fall back to PIL decoding.
        if not isinstance(processed_image, np.ndarray):
            processed_image = np.array(Image.open(tmp_path))
        # Detect text regions with CRAFT.
        outputs = craft_model.predict(processed_image)
        if not outputs or "boxes" not in outputs[0]:
            return Image.fromarray(processed_image), "No text detected"
        boxes = outputs[0]["boxes"]
        print(f"Debug: Boxes structure = {boxes}")  # Log the exact structure
        pil_image = Image.fromarray(processed_image)
        texts = []
        for box in boxes:
            rect = _box_to_rect(box)
            if rect is None:
                print(f"Debug: Skipping invalid box {box}")  # Log invalid boxes
                continue
            crop = pil_image.crop(rect)
            # TrOCR: encode the crop, generate token ids, decode to text.
            pixel_values = processor(images=crop, return_tensors="pt").pixel_values
            generated_ids = trocr_model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            texts.append(text)
        # Draw boxes on the image.
        result_image = draw_boxes(processed_image, boxes)
        result_pil = Image.fromarray(result_image)
        # Join recognized texts.
        text_data = " ".join(texts) if texts else "No text recognized"
        return result_pil, f"Recognized text: {text_data}"
    except Exception as e:
        return Image.fromarray(np.array(image)), f"Error: {str(e)}"
    finally:
        # Clean up temporary file; guard against it already being gone.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Build the Gradio UI: one image input, annotated image + text outputs.
demo_input = gr.Image(type="pil", label="Upload any image format")
demo_outputs = [
    gr.Image(type="pil", label="Detected Text Image"),
    gr.Text(label="Recognized Text"),
]
interface = gr.Interface(
    fn=recognize_handwritten_text,
    inputs=demo_input,
    outputs=demo_outputs,
    title="Handwritten Text Detection and Recognition",
    description="Upload an image in any format (JPEG, PNG, BMP, etc.) to detect and recognize handwritten text.",
)
# Start the web app.
interface.launch()