Zeeshan24 committed on
Commit
4e00f7b
·
verified ·
1 Parent(s): c2bf072

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -2,38 +2,38 @@ import streamlit as st
2
  import cv2
3
  import numpy as np
4
  from PIL import Image
 
5
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
6
  import re
7
 
8
- # Install TrOCR Model for Handwritten OCR
9
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
 
12
- # Load pre-trained QA model
13
  qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
14
 
15
- # Preprocess image for better OCR performance
16
  def preprocess_image(image_file):
17
- # Convert image to OpenCV format (numpy array)
18
  image = np.array(Image.open(image_file).convert("RGB"))
19
-
20
- # Preprocessing: Grayscale, blur, threshold (to clean up image)
21
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # Convert to grayscale
22
- blurred = cv2.GaussianBlur(gray, (5, 5), 0) # Remove noise
23
- thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1] # Enhance contrast
24
-
25
- # Convert back to RGB (3-channel) format for compatibility with TrOCR
26
  preprocessed_image = cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB)
27
- return Image.fromarray(preprocessed_image) # Convert back to PIL format
28
 
29
- # Extract text using TrOCR
30
- def extract_text_from_handwriting(image):
 
 
 
 
31
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
32
  generated_ids = model.generate(pixel_values)
33
  extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
  return extracted_text
35
 
36
- # Extract student info
37
  def extract_student_info(text):
38
  name = re.search(r"NAME\s*=\s*([\w\s]+)", text, re.IGNORECASE)
39
  roll_no = re.search(r"Roll\s*NO\s*=\s*(\d+)", text, re.IGNORECASE)
@@ -41,7 +41,7 @@ def extract_student_info(text):
41
  roll_number = roll_no.group(1).strip() if roll_no else "Unknown"
42
  return student_name, roll_number
43
 
44
- # Extract questions
45
  def extract_questions_from_text(text):
46
  questions = re.findall(r'(?:[^\n]*\?)', text)
47
  return questions
@@ -59,14 +59,20 @@ st.write("Upload an image or handwritten file to process.")
59
  uploaded_image = st.file_uploader("Upload Handwritten Image", type=["png", "jpg", "jpeg"])
60
 
61
  if uploaded_image:
62
- # Preprocess the image
63
  st.image(uploaded_image, caption="Original Image", use_container_width=True)
 
 
64
  preprocessed_image = preprocess_image(uploaded_image)
65
  st.image(preprocessed_image, caption="Preprocessed Image", use_container_width=True)
66
 
67
- # Extract text using TrOCR
68
- extracted_text = extract_text_from_handwriting(preprocessed_image)
69
- st.subheader("Extracted Text")
 
 
 
 
 
70
  st.text(extracted_text)
71
 
72
  # Extract student info
 
2
  import cv2
3
  import numpy as np
4
  from PIL import Image
5
+ import pytesseract
6
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
7
  import re
8
 
9
# Handwriting recognition: TrOCR is a vision encoder-decoder pretrained on
# handwritten text; the processor handles image -> pixel-tensor conversion
# and token decoding.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Extractive question-answering pipeline (DistilBERT fine-tuned on SQuAD),
# used later to grade answers against the extracted text.
qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
15
 
16
# Image clean-up ahead of OCR.
def preprocess_image(image_file):
    """Load *image_file*, denoise and binarize it, and return a PIL image.

    Pipeline: RGB load -> grayscale -> 5x5 Gaussian blur -> Otsu threshold.
    The binary result is expanded back to 3 channels because TrOCR expects
    RGB input.
    """
    rgb_array = np.array(Image.open(image_file).convert("RGB"))
    grayscale = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (5, 5), 0)
    # Otsu picks the threshold automatically (the 0 passed here is ignored).
    binarized = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    three_channel = cv2.cvtColor(binarized, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(three_channel)
24
 
25
# Fast OCR path: plain Tesseract (works best on clear/printed text).
def extract_text_with_tesseract(image):
    """Run Tesseract OCR on *image* and return the raw recognized string."""
    recognized = pytesseract.image_to_string(image)
    return recognized
28
+
29
# Fallback OCR path: TrOCR handwriting model loaded at module level.
def extract_text_with_trocr(image):
    """Recognize handwriting in *image* with TrOCR and return it as a string.

    Uses the module-level ``processor``/``model`` pair; returns the first
    (and only) decoded sequence.
    """
    inputs = processor(images=image, return_tensors="pt")
    token_ids = model.generate(inputs.pixel_values)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
35
 
36
+ # Extract student name and roll number
37
  def extract_student_info(text):
38
  name = re.search(r"NAME\s*=\s*([\w\s]+)", text, re.IGNORECASE)
39
  roll_no = re.search(r"Roll\s*NO\s*=\s*(\d+)", text, re.IGNORECASE)
 
41
  roll_number = roll_no.group(1).strip() if roll_no else "Unknown"
42
  return student_name, roll_number
43
 
44
# Pull out every line-fragment that ends with a question mark.
def extract_questions_from_text(text):
    """Return all question strings found in *text*.

    A "question" is any maximal run of non-newline characters ending in
    ``?`` — at most one match per line, greedy up to the line's last ``?``.
    """
    question_pattern = re.compile(r"[^\n]*\?")
    return question_pattern.findall(text)
 
59
  uploaded_image = st.file_uploader("Upload Handwritten Image", type=["png", "jpg", "jpeg"])
60
 
61
  if uploaded_image:
 
62
  st.image(uploaded_image, caption="Original Image", use_container_width=True)
63
+
64
+ # Preprocess the image
65
  preprocessed_image = preprocess_image(uploaded_image)
66
  st.image(preprocessed_image, caption="Preprocessed Image", use_container_width=True)
67
 
68
+ # Attempt text extraction with Tesseract
69
+ st.subheader("Extracted Text:")
70
+ tesseract_text = extract_text_with_tesseract(preprocessed_image)
71
+ if len(tesseract_text.strip()) > 10:
72
+ extracted_text = tesseract_text # Use Tesseract output if it seems valid
73
+ else:
74
+ extracted_text = extract_text_with_trocr(preprocessed_image) # Use TrOCR fallback
75
+
76
  st.text(extracted_text)
77
 
78
  # Extract student info