Spaces:
Runtime error
Runtime error
import logging
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Set up module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the TrOCR processor and model once at import time; both come from the
# same checkpoint, so the identifier is named once and reused.
_TROCR_CHECKPOINT = 'microsoft/trocr-base-handwritten'

try:
    processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
    model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
    # Move the model to the GPU only when one is available.
    if torch.cuda.is_available():
        model.to('cuda')
except Exception as load_error:
    # Log the failure, then let it propagate so startup aborts.
    logger.error(f"Error loading model: {load_error}")
    raise
def process_image(image, max_length=128):
    """Extract text from a PIL image using the module-level TrOCR model.

    Args:
        image: A PIL image. Images whose mode is not ``'RGB'`` are converted
            to RGB before being handed to the processor.
        max_length: Maximum generated sequence length, forwarded to
            ``model.generate``. Defaults to 128.

    Returns:
        The decoded text with leading/trailing whitespace stripped. If any
        step fails, the exception is logged and a string of the form
        ``"Error processing image: <message>"`` is returned instead of raising.
    """
    try:
        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Prepare image for model
        pixel_values = processor(image, return_tensors="pt").pixel_values
        # Put the input on whatever device the model actually lives on, so the
        # two cannot disagree (rather than re-deriving it from CUDA availability).
        pixel_values = pixel_values.to(model.device)

        # Generate text
        generated_ids = model.generate(pixel_values, max_length=max_length)
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        return generated_text.strip()
    except Exception as e:
        # The error is converted into a return value, so record the traceback
        # here; otherwise it would be lost.
        logger.exception("Error processing image: %s", e)
        return f"Error processing image: {e}"
def analyze_image(input_image):
    """Run OCR on an uploaded image file and format the result for display.

    Args:
        input_image: Path to the image file, or ``None``/empty when nothing
            was uploaded.

    Returns:
        A formatted report containing the extracted text plus character and
        word counts; ``"Please upload an image."`` when no image was given;
        or ``"Error analyzing image: <message>"`` if opening or processing
        the file raises.
    """
    # Treat an empty path the same as a missing upload.
    if not input_image:
        return "Please upload an image."

    try:
        # Open the file in a context manager so the handle is released even
        # if OCR fails part-way through.
        with Image.open(input_image) as image:
            extracted_text = process_image(image)

        # Format response
        divider = '-' * 40
        return (
            f"📝 Extracted Text:\n"
            f"{divider}\n"
            f"{extracted_text}\n"
            f"{divider}\n"
            f"📊 Statistics:\n"
            f"• Characters: {len(extracted_text)}\n"
            f"• Words: {len(extracted_text.split())}\n"
        )
    except Exception as e:
        # The error is returned to the UI as text, so keep the traceback in
        # the log.
        logger.exception("Error in analysis: %s", e)
        return f"Error analyzing image: {e}"
# Create Gradio interface
# Offer only the example files that are actually present next to the app.
# NOTE(review): a missing example file can make Interface construction fail at
# startup (Gradio loads/caches examples eagerly on Spaces) - confirm against
# the deployed Gradio version.
_available_examples = [
    [example_name]
    for example_name in ("example1.jpg", "example2.png")
    if Path(example_name).is_file()
]

demo = gr.Interface(
    fn=analyze_image,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs=gr.Textbox(label="Extracted Text", lines=10),
    title="📷 Smart OCR Text Extractor",
    description=(
        "\n"
        "Extract text from images using Microsoft's TrOCR model.\n"
        "Supports handwritten and printed text.\n"
    ),
    theme=gr.themes.Soft(),
    # Pass None rather than an empty list when no example file exists.
    examples=_available_examples or None,
)

if __name__ == "__main__":
    demo.launch()