Josebert committed on
Commit
a5ad694
·
verified ·
1 Parent(s): 9d4f5b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -12
app.py CHANGED
@@ -1,20 +1,57 @@
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
- import requests
5
- from io import BytesIO
6
 
7
- # Load TrOCR model and processor
8
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
9
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def extract_text_from_image(image):
12
  """Extract text from an uploaded image using Hugging Face TrOCR model."""
13
- image = image.convert("RGB")
14
- pixel_values = processor(image, return_tensors="pt").pixel_values
15
- generated_ids = model.generate(pixel_values)
16
- extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
17
- return extracted_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Create Gradio Interface
20
  interface = gr.Interface(
@@ -22,8 +59,9 @@ interface = gr.Interface(
22
  inputs=gr.Image(type="pil"),
23
  outputs=gr.Textbox(label="Extracted Text"),
24
  title="OCR Text Extractor",
25
- description="Upload an image to extract text using Hugging Face's TrOCR model."
 
26
  )
27
 
28
  if __name__ == "__main__":
29
- interface.launch(share=True)
 
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ import torch
5
+ import traceback
6
 
7
# Checkpoint used for both the processor and the model.
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"

# (processor, model) pair built by the first load_model() call; None until then.
_loaded_trocr = None


def load_model():
    """Load the TrOCR model and processor, reusing them after the first call.

    extract_text_from_image() calls this function for every request, so the
    pair is built once and kept at module level instead of being loaded again
    (and moved to the GPU again) for each image.

    Returns:
        A ``(processor, model)`` tuple. The model is moved to CUDA when
        ``torch.cuda.is_available()`` reports a usable device.
    """
    global _loaded_trocr
    if _loaded_trocr is None:
        processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
        model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
        if torch.cuda.is_available():
            model = model.to("cuda")
        # Publish the pair only after both loads succeeded, so a failed
        # attempt is retried on the next call instead of caching garbage.
        _loaded_trocr = (processor, model)
    return _loaded_trocr
14
+
15
def preprocess_image(image, max_size=1000):
    """Convert *image* to RGB and shrink it so no side exceeds *max_size*.

    Args:
        image: A PIL image (anything exposing ``mode``, ``size``,
            ``convert`` and ``resize`` in the PIL way).
        max_size: Upper bound, in pixels, for the longer side. Images that
            already fit are not resized.

    Returns:
        The input object itself when it is already RGB and small enough,
        otherwise the converted and/or resized image.
    """
    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize if image is too large, keeping the aspect ratio
    longest_side = max(image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp each side to at least 1 px: plain truncation yields a
        # zero-length side for very elongated images (e.g. 4000x3), and a
        # zero-sized target makes the resize fail.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.LANCZOS)

    return image


def extract_text_from_image(image):
    """Extract text from an uploaded image using Hugging Face TrOCR model.

    Args:
        image: The PIL image handed over by the Gradio input, or ``None``
            when nothing was uploaded.

    Returns:
        The recognised text with surrounding whitespace stripped. On failure
        a human-readable string starting with ``"Error"`` is returned instead
        of raising, so the message shows up in the output textbox.
    """
    if image is None:
        return "Error: No image provided"

    try:
        # Load model and processor
        processor, model = load_model()

        # Preprocess image
        image = preprocess_image(image)

        # Extract text
        pixel_values = processor(image, return_tensors="pt").pixel_values
        # Follow the device the model actually lives on rather than
        # re-checking CUDA availability, so inputs and weights always match.
        pixel_values = pixel_values.to(model.device)

        generated_ids = model.generate(pixel_values)
        extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return extracted_text.strip()

    except Exception as e:
        # UI boundary: keep the full traceback in the server log (stderr) and
        # show the user only the short message, so internal paths and stack
        # frames are not exposed through the publicly shared interface.
        traceback.print_exc()
        return f"Error processing image: {e}"
55
 
56
  # Create Gradio Interface
57
  interface = gr.Interface(
 
59
  inputs=gr.Image(type="pil"),
60
  outputs=gr.Textbox(label="Extracted Text"),
61
  title="OCR Text Extractor",
62
+ description="Upload an image to extract text using Hugging Face's TrOCR model.",
63
+ examples=["sample1.jpg", "sample2.jpg"] # Add example images if you have them
64
  )
65
 
66
  if __name__ == "__main__":
67
+ interface.launch(share=True)