Josebert commited on
Commit
0e82584
·
verified ·
1 Parent(s): fd3de67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -51
app.py CHANGED
@@ -1,67 +1,87 @@
1
  import gradio as gr
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
- from PIL import Image
4
  import torch
5
- import traceback
 
 
6
 
7
- def load_model():
8
- """Load the TrOCR model and processor."""
9
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
- if torch.cuda.is_available():
12
- model = model.to("cuda")
13
- return processor, model
14
 
15
- def preprocess_image(image):
16
- """Preprocess the input image."""
17
- # Convert to RGB if needed
18
- if image.mode != "RGB":
19
- image = image.convert("RGB")
20
-
21
- # Resize if image is too large
22
- max_size = 1000
23
- if max(image.size) > max_size:
24
- ratio = max_size / max(image.size)
25
- new_size = tuple(int(dim * ratio) for dim in image.size)
26
- image = image.resize(new_size, Image.LANCZOS)
27
-
28
- return image
29
 
30
- def extract_text_from_image(image):
31
- """Extract text from an uploaded image using Hugging Face TrOCR model."""
32
  try:
33
- if image is None:
34
- return "Error: No image provided"
35
-
36
- # Load model and processor
37
- processor, model = load_model()
38
-
39
- # Preprocess image
40
- image = preprocess_image(image)
41
 
42
- # Extract text
43
  pixel_values = processor(image, return_tensors="pt").pixel_values
44
  if torch.cuda.is_available():
45
- pixel_values = pixel_values.to("cuda")
46
-
47
- generated_ids = model.generate(pixel_values)
48
- extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
 
50
- return extracted_text.strip()
 
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
- error_msg = f"Error processing image: {str(e)}\n{traceback.format_exc()}"
54
- return error_msg
55
 
56
- # Create Gradio Interface
57
- interface = gr.Interface(
58
- fn=extract_text_from_image,
59
- inputs=gr.Image(type="pil"),
60
- outputs=gr.Textbox(label="Extracted Text"),
61
- title="OCR Text Extractor",
62
- description="Upload an image to extract text using Hugging Face's TrOCR model.",
63
- examples=["sample1.jpg", "sample2.jpg"] # Add example images if you have them
 
 
 
 
 
 
 
64
  )
65
 
66
  if __name__ == "__main__":
67
- interface.launch(share=True)
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from PIL import Image
4
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
+ import logging
6
 
7
+ # Configure logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
 
 
 
 
10
 
11
+ # Initialize TrOCR model and processor
12
+ try:
13
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
14
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
15
+ if torch.cuda.is_available():
16
+ model.to('cuda')
17
+ except Exception as e:
18
+ logger.error(f"Error loading model: {e}")
19
+ raise
 
 
 
 
 
20
 
21
+ def process_image(image):
22
+ """Process image and extract text using TrOCR"""
23
  try:
24
+ # Convert to RGB if needed
25
+ if image.mode != 'RGB':
26
+ image = image.convert('RGB')
 
 
 
 
 
27
 
28
+ # Prepare image for model
29
  pixel_values = processor(image, return_tensors="pt").pixel_values
30
  if torch.cuda.is_available():
31
+ pixel_values = pixel_values.to('cuda')
 
 
 
32
 
33
+ # Generate text
34
+ generated_ids = model.generate(pixel_values, max_length=128)
35
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
 
37
+ return generated_text.strip()
38
+ except Exception as e:
39
+ logger.error(f"Error processing image: {e}")
40
+ return f"Error processing image: {str(e)}"
41
+
42
+ def analyze_image(input_image):
43
+ """Main function to handle image analysis"""
44
+ if input_image is None:
45
+ return "Please upload an image."
46
+
47
+ try:
48
+ # Open and process image
49
+ image = Image.open(input_image)
50
+
51
+ # Extract text
52
+ extracted_text = process_image(image)
53
+
54
+ # Format response
55
+ response = f"""📝 Extracted Text:
56
+ {'-' * 40}
57
+ {extracted_text}
58
+ {'-' * 40}
59
+
60
+ 📊 Statistics:
61
+ • Characters: {len(extracted_text)}
62
+ • Words: {len(extracted_text.split())}
63
+ """
64
+ return response
65
  except Exception as e:
66
+ logger.error(f"Error in analysis: {e}")
67
+ return f"Error analyzing image: {str(e)}"
68
 
69
+ # Create Gradio interface
70
+ demo = gr.Interface(
71
+ fn=analyze_image,
72
+ inputs=gr.Image(type="filepath", label="Upload Image"),
73
+ outputs=gr.Textbox(label="Extracted Text", lines=10),
74
+ title="📷 Smart OCR Text Extractor",
75
+ description="""
76
+ Extract text from images using Microsoft's TrOCR model.
77
+ Supports handwritten and printed text.
78
+ """,
79
+ theme=gr.themes.Soft(),
80
+ examples=[
81
+ ["example1.jpg"],
82
+ ["example2.png"]
83
+ ]
84
  )
85
 
86
  if __name__ == "__main__":
87
+ demo.launch()