File size: 2,560 Bytes
2c2d398
a5ad694
0e82584
 
 
c0d9719
0e82584
 
 
a5ad694
0e82584
 
 
 
 
 
 
 
 
c0d9719
0e82584
 
a5ad694
0e82584
 
 
a5ad694
0e82584
a5ad694
 
0e82584
a5ad694
0e82584
 
 
a5ad694
0e82584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5ad694
0e82584
 
c0d9719
0e82584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a52e06
fb37c01
2c2d398
0e82584
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Checkpoint id shared by the processor and the model so they cannot drift apart.
MODEL_NAME = 'microsoft/trocr-base-handwritten'

# Initialize TrOCR model and processor once at import time; every request
# below reuses these module-level objects.
try:
    processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    if torch.cuda.is_available():
        model.to('cuda')
except Exception as e:
    # logger.exception keeps the full traceback (download/cache/auth failures
    # are otherwise hard to diagnose); re-raise so the app never starts
    # without a model.
    logger.exception("Error loading model: %s", e)
    raise

def process_image(image, *, max_length=128):
    """Run TrOCR on a PIL image and return the decoded text.

    Args:
        image: A PIL ``Image``. It is converted to RGB first if it is in any
            other mode.
        max_length: Maximum sequence length forwarded to ``model.generate``
            (defaults to 128, the previously hard-coded value).

    Returns:
        The decoded text with surrounding whitespace stripped. On any failure
        the error is logged with its traceback and a string beginning with
        ``"Error processing image: "`` is returned instead of raising.
    """
    try:
        # The TrOCR processor expects 3-channel input, so grayscale, palette
        # and RGBA uploads are normalized here.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Move inputs to whichever device the model actually lives on rather
        # than re-deriving it from torch.cuda.is_available().
        pixel_values = processor(image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(model.device)

        generated_ids = model.generate(pixel_values, max_length=max_length)
        # A single image was encoded, so the batch has exactly one sequence.
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        return generated_text.strip()
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        return f"Error processing image: {str(e)}"

def analyze_image(input_image):
    """Extract text from an uploaded image file and format a short report.

    Args:
        input_image: Path to the image file to analyze, or ``None`` when no
            image was provided.

    Returns:
        A human-readable string with the extracted text and its character and
        word counts, ``"Please upload an image."`` when ``input_image`` is
        ``None``, or an ``"Error analyzing image: ..."`` message on failure.
    """
    if input_image is None:
        return "Please upload an image."

    try:
        # Image.open is lazy and keeps the underlying file open until the
        # image is closed; the context manager releases it after OCR instead
        # of leaking one handle per request.
        with Image.open(input_image) as image:
            extracted_text = process_image(image)

        # Format response
        response = f"""📝 Extracted Text:
{'-' * 40}
{extracted_text}
{'-' * 40}

📊 Statistics:
• Characters: {len(extracted_text)}
• Words: {len(extracted_text.split())}
"""
        return response
    except Exception as e:
        logger.exception("Error in analysis: %s", e)
        return f"Error analyzing image: {str(e)}"

# Create Gradio interface. The image input uses type="filepath", so
# analyze_image receives a path and opens the file itself.
demo = gr.Interface(
    fn=analyze_image,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs=gr.Textbox(label="Extracted Text", lines=10),
    title="📷 Smart OCR Text Extractor",
    # NOTE(review): the loaded checkpoint is 'microsoft/trocr-base-handwritten';
    # confirm printed-text quality before advertising it here.
    description="""
    Extract text from images using Microsoft's TrOCR model.
    Supports handwritten and printed text.
    """,
    theme=gr.themes.Soft(),
    # NOTE(review): relative paths — confirm these files ship with the app.
    examples=[
        ["example1.jpg"],
        ["example2.png"],
    ],
)

if __name__ == "__main__":
    # Executed directly (not imported): start the Gradio server for `demo`.
    start_server = demo.launch
    start_server()