Zeeshan24 commited on
Commit
421fc43
·
verified ·
1 Parent(s): a08c821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -1,24 +1,33 @@
1
  import streamlit as st
 
 
2
  from PIL import Image
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
4
  import re
5
 
6
- # Load TrOCR Model for Handwritten OCR
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
8
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
9
 
10
  # Load pre-trained QA model
11
  qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
12
 
13
- # Function to extract text using TrOCR
14
- def extract_text_from_handwriting(image_file):
15
- image = Image.open(image_file).convert("RGB")
 
 
 
 
 
 
 
16
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
17
  generated_ids = model.generate(pixel_values)
18
  extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
19
  return extracted_text
20
 
21
- # Extract student name and roll number
22
  def extract_student_info(text):
23
  name = re.search(r"NAME\s*=\s*([\w\s]+)", text, re.IGNORECASE)
24
  roll_no = re.search(r"Roll\s*NO\s*=\s*(\d+)", text, re.IGNORECASE)
@@ -26,12 +35,12 @@ def extract_student_info(text):
26
  roll_number = roll_no.group(1).strip() if roll_no else "Unknown"
27
  return student_name, roll_number
28
 
29
- # Extract questions from the text
30
  def extract_questions_from_text(text):
31
- questions = re.findall(r'(?:[^\n]*\?)', text) # Extract sentences ending with "?"
32
  return questions
33
 
34
- # Grading function using QA model
35
  def grade_answer(question, context):
36
  result = qa_pipeline(question=question, context=context)
37
  return result['score'], "Correct" if result['score'] > 0.5 else "Incorrect"
@@ -44,10 +53,13 @@ st.write("Upload an image or handwritten file to process.")
44
  uploaded_image = st.file_uploader("Upload Handwritten Image", type=["png", "jpg", "jpeg"])
45
 
46
  if uploaded_image:
47
- st.image(uploaded_image, caption="Uploaded Handwritten File", use_container_width=True)
 
 
 
48
 
49
  # Extract text using TrOCR
50
- extracted_text = extract_text_from_handwriting(uploaded_image)
51
  st.subheader("Extracted Text")
52
  st.text(extracted_text)
53
 
 
1
  import streamlit as st
2
+ import cv2
3
+ import numpy as np
4
  from PIL import Image
5
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
6
  import re
7
 
8
+ # Install TrOCR Model for Handwritten OCR
9
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
 
12
  # Load pre-trained QA model
13
  qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
14
 
15
+ # Preprocess image for better OCR performance
16
+ def preprocess_image(image_file):
17
+ image = np.array(Image.open(image_file).convert("RGB"))
18
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # Convert to grayscale
19
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0) # Remove noise
20
+ thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1] # Increase contrast
21
+ return Image.fromarray(thresh) # Convert back to PIL format
22
+
23
+ # Extract text using TrOCR
24
+ def extract_text_from_handwriting(image):
25
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
26
  generated_ids = model.generate(pixel_values)
27
  extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
28
  return extracted_text
29
 
30
+ # Extract student info
31
  def extract_student_info(text):
32
  name = re.search(r"NAME\s*=\s*([\w\s]+)", text, re.IGNORECASE)
33
  roll_no = re.search(r"Roll\s*NO\s*=\s*(\d+)", text, re.IGNORECASE)
 
35
  roll_number = roll_no.group(1).strip() if roll_no else "Unknown"
36
  return student_name, roll_number
37
 
38
+ # Extract questions
39
  def extract_questions_from_text(text):
40
+ questions = re.findall(r'(?:[^\n]*\?)', text)
41
  return questions
42
 
43
+ # Grade answers
44
  def grade_answer(question, context):
45
  result = qa_pipeline(question=question, context=context)
46
  return result['score'], "Correct" if result['score'] > 0.5 else "Incorrect"
 
53
  uploaded_image = st.file_uploader("Upload Handwritten Image", type=["png", "jpg", "jpeg"])
54
 
55
  if uploaded_image:
56
+ # Preprocess the image
57
+ st.image(uploaded_image, caption="Original Image", use_container_width=True)
58
+ preprocessed_image = preprocess_image(uploaded_image)
59
+ st.image(preprocessed_image, caption="Preprocessed Image", use_container_width=True)
60
 
61
  # Extract text using TrOCR
62
+ extracted_text = extract_text_from_handwriting(preprocessed_image)
63
  st.subheader("Extracted Text")
64
  st.text(extracted_text)
65