# OMR_checker / app.py
# (Hugging Face Space header: author yuyutsu07, commit f8caafa verified)
# app.py
import gradio as gr
import os
import google.generativeai as genai
from PIL import Image
import numpy as np
import cv2
import io
import json
import re
import imutils # For easier contour sorting and perspective transform helpers
from skimage.metrics import structural_similarity as ssim # For potential template matching (optional enhancement)
import math
import time
# --- Configuration & Constants ---
# IMPORTANT: Add your GOOGLE_API_KEY as a Secret in your Hugging Face Space settings
try:
    # Configure the Gemini client once at import time. On any failure the app
    # degrades gracefully: gemini_model is set to None and downstream code
    # checks for that before calling the API.
    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
    if not GOOGLE_API_KEY:
        raise ValueError("⚠️ GOOGLE_API_KEY not found in environment variables/secrets.")
    genai.configure(api_key=GOOGLE_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash-latest')
    print("βœ… Gemini configured successfully.")
except Exception as e:
    print(f"❌ Error configuring Gemini: {e}")
    gemini_model = None  # Sentinel checked by extract_answer_key_gemini()
# OMR Sheet Configuration (CRITICAL TUNING PARAMETERS)
OMR_TOTAL_QUESTIONS = 180  # Standard NEET paper length
OMR_OPTIONS_PER_QUESTION = 4  # Options 1-4 per question
OMR_BUBBLE_RADIUS_APPROX = 7  # Approximate radius for filtering - NEEDS TUNING
OMR_ASPECT_RATIO_TOLERANCE = 0.3  # How much deviation from a perfect circle (width/height)
OMR_SOLIDITY_THRESHOLD = 0.8  # How 'solid' the contour is (area / convex hull area) - Helps filter non-circles
OMR_MARK_THRESHOLD_RATIO = 0.4  # Ratio of non-zero pixels within bubble to consider it 'marked' - NEEDS TUNING
OMR_COLUMNS = 4  # Number of question columns in the grid
# --- Helper Functions ---
def pil_to_cv2(pil_image):
    """Convert a PIL image (RGB) into an OpenCV ndarray (BGR channel order)."""
    rgb_array = np.array(pil_image)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
def cv2_to_pil(cv2_image):
    """Convert an OpenCV ndarray (BGR channel order) into a PIL image (RGB)."""
    rgb_array = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)
def order_points(pts):
    """Order four corner points as top-left, top-right, bottom-right, bottom-left.

    The point with the smallest x+y sum is the top-left corner and the
    largest sum is the bottom-right; the smallest y-x difference is the
    top-right and the largest is the bottom-left.
    """
    ordered = np.zeros((4, 2), dtype="float32")
    coord_sums = pts.sum(axis=1)
    coord_diffs = np.diff(pts, axis=1)
    ordered[0] = pts[np.argmin(coord_sums)]   # top-left
    ordered[2] = pts[np.argmax(coord_sums)]   # bottom-right
    ordered[1] = pts[np.argmin(coord_diffs)]  # top-right
    ordered[3] = pts[np.argmax(coord_diffs)]  # bottom-left
    return ordered
def four_point_transform(image, pts):
    """Warp `image` so the quadrilateral `pts` becomes an axis-aligned rectangle.

    The output size is derived from the longest opposing edges of the
    (ordered) quadrilateral, so the warped sheet keeps its proportions.
    """
    corners = order_points(pts)
    (tl, tr, br, bl) = corners
    # Output width: the longer of the top and bottom edges.
    bottom_width = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    top_width = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    out_w = max(int(bottom_width), int(top_width))
    # Output height: the longer of the right and left edges.
    right_height = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    left_height = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    out_h = max(int(right_height), int(left_height))
    destination = np.array(
        [[0, 0],
         [out_w - 1, 0],
         [out_w - 1, out_h - 1],
         [0, out_h - 1]], dtype="float32")
    matrix = cv2.getPerspectiveTransform(corners, destination)
    return cv2.warpPerspective(image, matrix, (out_w, out_h))
# --- Core Logic Functions ---
def extract_answer_key_gemini(answer_key_pil_image):
    """Use Gemini to extract a NEET answer key from an image.

    Args:
        answer_key_pil_image: PIL image of the printed answer key.

    Returns:
        dict mapping question number (1..OMR_TOTAL_QUESTIONS) to the correct
        option (int 1-4), with 0 used for ambiguous/missing answers.

    Raises:
        gr.Error: if Gemini is not configured, the image is missing, or the
            model response cannot be parsed as an answer key.
    """
    if not gemini_model:
        raise gr.Error("Gemini model not configured. Cannot process answer key image.")
    if answer_key_pil_image is None:
        raise gr.Error("Answer Key image is missing.")
    print("πŸ“„ Processing Answer Key with Gemini...")
    start_time = time.time()
    img_byte_arr = io.BytesIO()
    answer_key_pil_image.save(img_byte_arr, format='PNG')  # PNG: lossless, keeps small print readable
    img_bytes = img_byte_arr.getvalue()
    # BUGFIX: the previous prompt's example showed integer keys ({1: 2, ...}),
    # which is NOT valid JSON; models copying that format made json.loads fail.
    # The example now uses quoted string keys.
    prompt = f"""
Analyze the provided image of a NEET answer key. The answers are numerical options (1, 2, 3, 4).
Extract the correct numerical option for each question number from 1 to {OMR_TOTAL_QUESTIONS}.
Output MUST be a valid JSON object mapping the question number (as a string key) to the correct numerical option (as an integer value, 1, 2, 3, or 4).
Example format: {{"1": 2, "2": 4, "3": 1, ..., "{OMR_TOTAL_QUESTIONS}": 3}}
Ensure all {OMR_TOTAL_QUESTIONS} questions are present in the JSON. If a question's answer is ambiguous or unreadable from the image, use the integer 0 as the value for that question number. Do not omit any question numbers from 1 to {OMR_TOTAL_QUESTIONS}.
"""
    try:
        response = gemini_model.generate_content([prompt, {'mime_type': 'image/png', 'data': img_bytes}])
        # Accept either a fenced ```json block or bare JSON text.
        match = re.search(r"```json\s*(\{.*?\})\s*```", response.text, re.DOTALL)
        cleaned_response = match.group(1) if match else response.text.strip()
        try:
            parsed_json = json.loads(cleaned_response)
        except json.JSONDecodeError:
            # Lenient retry: quote bare integer keys ({1: 2} -> {"1": 2}),
            # a common model mistake that is otherwise unparseable JSON.
            repaired = re.sub(r'([{,]\s*)(\d+)(\s*:)', r'\1"\2"\3', cleaned_response)
            parsed_json = json.loads(repaired)
        answer_key = {}
        valid_options = {0, 1, 2, 3, 4}  # 0 == ambiguous/unreadable in the key
        expected_keys_found = 0
        for q_num in range(1, OMR_TOTAL_QUESTIONS + 1):
            # JSON keys are strings, but tolerate int keys from lenient parses.
            value_raw = parsed_json.get(str(q_num), parsed_json.get(q_num))
            if value_raw is None:
                print(f"⚠️ Missing answer for Q{q_num} in Gemini response. Using 0.")
                answer_key[q_num] = 0
                continue
            expected_keys_found += 1
            try:
                value_int = int(value_raw)
            except (ValueError, TypeError):
                print(f"⚠️ Non-integer value '{value_raw}' for Q{q_num}. Using 0.")
                answer_key[q_num] = 0
                continue
            if value_int in valid_options:
                answer_key[q_num] = value_int
            else:
                print(f"⚠️ Invalid option value '{value_raw}' for Q{q_num}. Using 0.")
                answer_key[q_num] = 0
        # Completeness report: the loop above already defaulted missing keys to 0.
        if expected_keys_found != OMR_TOTAL_QUESTIONS:
            print(f"πŸ”₯ Critical Warning: Gemini response contained {expected_keys_found} entries, expected {OMR_TOTAL_QUESTIONS}. Missing answers defaulted to 0.")
        print(f"βœ… Answer Key extraction finished in {time.time() - start_time:.2f}s. Processed {OMR_TOTAL_QUESTIONS} questions.")
        return answer_key
    except json.JSONDecodeError as json_err:
        print(f"❌ Error parsing Gemini JSON response: {json_err}")
        print(f"Raw Gemini Response Text:\n{response.text}")
        raise gr.Error("Failed to parse answer key from Gemini response. Ensure image is clear and format is standard.")
    except Exception as e:
        print(f"❌ Error calling Gemini API or processing response: {e}")
        raise gr.Error(f"Error processing answer key with Gemini: {e}")
def extract_omr_answers_cv(omr_sheet_pil_image):
    """
    Processes the OMR sheet image using OpenCV to find marked bubbles.

    Pipeline: perspective correction -> adaptive threshold -> bubble contour
    filtering (area/aspect/solidity heuristics) -> column-then-row sorting ->
    per-question fill-ratio mark detection with multiple-mark disambiguation.

    Returns:
        dict: Mapping question number (int) to marked option (int 1-4) or None.
        np.ndarray: Annotated OpenCV image for display.
    """
    if omr_sheet_pil_image is None:
        raise gr.Error("OMR Sheet image is missing.")
    print("πŸ“„ Processing OMR Sheet with OpenCV...")
    start_time = time.time()
    img_cv = pil_to_cv2(omr_sheet_pil_image)
    img_height, img_width = img_cv.shape[:2]
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # 1. Perspective Correction (Attempt)
    print(" - Applying perspective correction...")
    # Find contours on the blurred image for the outline
    cnts = cv2.findContours(cv2.Canny(blurred, 50, 150), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)
    docCnt = None
    if len(cnts) > 0:
        # Largest contours first; the sheet outline should dominate the frame.
        cnts = sorted(cnts, key=cv2.contourArea, reverse=True)
        for c in cnts:
            peri = cv2.arcLength(c, True)
            approx = cv2.approxPolyDP(c, 0.02 * peri, True)
            # Accept the first quadrilateral covering >10% of the frame area.
            if len(approx) == 4 and cv2.contourArea(c) > (img_width * img_height * 0.1):  # Ensure contour is large enough
                docCnt = approx
                break
    # Fall back to the un-warped image unless a valid outline was found.
    warped_gray = gray
    warped_color = img_cv
    if docCnt is not None:
        try:
            warped_color = four_point_transform(img_cv, docCnt.reshape(4, 2))
            warped_gray = cv2.cvtColor(warped_color, cv2.COLOR_BGR2GRAY)  # Use warped color for gray
            print(" - Perspective warp applied.")
        except Exception as e:
            print(f" - Warning: Perspective warp failed ({e}). Using original image.")
            warped_gray = gray  # Revert to original gray if warp fails
            warped_color = img_cv
    else:
        print(" - Warning: Document outline (4 corners) not found. Using original image.")
    # 2. Thresholding (binary-inverse: pencil marks become white foreground)
    print(" - Applying adaptive thresholding...")
    thresh = cv2.adaptiveThreshold(warped_gray, 255,
                                   cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 19, 5)  # Tunable params
    # 3. Find Bubble Contours
    print(" - Finding bubble contours...")
    cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)
    questionCnts = []
    bubble_count = 0
    # Dynamic radius calculation attempt based on warped image size (heuristic)
    warped_h, warped_w = warped_gray.shape[:2]
    # Estimate based on typical OMR layouts (might need adjustment)
    approx_bubble_area_min = int(math.pi * (warped_w * 0.008)**2)  # Lower bound area guess
    approx_bubble_area_max = int(math.pi * (warped_w * 0.02)**2)  # Upper bound area guess
    print(f" - Estimated bubble area range: {approx_bubble_area_min}-{approx_bubble_area_max}")
    for c in cnts:
        (x, y, w, h) = cv2.boundingRect(c)
        ar = w / float(h) if h > 0 else 0
        area = cv2.contourArea(c)
        hull = cv2.convexHull(c)
        hull_area = cv2.contourArea(hull) if hull is not None else 0
        solidity = area / float(hull_area) if hull_area > 0 else 0
        # Keep only contours that are bubble-sized, roughly circular, and solid.
        if approx_bubble_area_min < area < approx_bubble_area_max and \
           abs(ar - 1.0) < OMR_ASPECT_RATIO_TOLERANCE and \
           solidity > OMR_SOLIDITY_THRESHOLD:
            questionCnts.append(c)
            bubble_count += 1
    print(f" - Found {bubble_count} potential bubble contours after filtering.")
    expected_bubbles = OMR_TOTAL_QUESTIONS * OMR_OPTIONS_PER_QUESTION
    if bubble_count < (expected_bubbles * 0.7):  # Stricter check now
        print(f"πŸ”₯ Critical Warning: Found significantly fewer bubbles ({bubble_count}) than expected ({expected_bubbles}). Grid detection might fail. Check image quality/parameters.")
        # Don't raise error yet, attempt grouping anyway
    if not questionCnts:
        raise gr.Error("No potential bubble contours found after filtering. Check OMR image quality, lighting, and processing parameters.")
    # 4. Sort Contours and Group into Questions
    print(" - Sorting and grouping contours...")
    # Sort by columns first, then rows (typical OMR reading order)
    num_rows_per_col = OMR_TOTAL_QUESTIONS // OMR_COLUMNS
    contours_per_col = num_rows_per_col * OMR_OPTIONS_PER_QUESTION
    # Estimate column breaks (heuristic based on width)
    col_width = warped_w / OMR_COLUMNS

    def get_col(contour):
        # Map a contour to its question column via its x position.
        x, _, _, _ = cv2.boundingRect(contour)
        return int(x // col_width)
    # Sort primarily by column, then by Y coordinate within column, then X within row (for options)
    questionCnts = sorted(questionCnts, key=lambda c: (get_col(c), cv2.boundingRect(c)[1], cv2.boundingRect(c)[0]))
    user_answers = {}
    annotated_img = warped_color.copy()
    processed_bubble_count = 0
    for q_index in range(OMR_TOTAL_QUESTIONS):
        current_q_num = q_index + 1
        start_index = q_index * OMR_OPTIONS_PER_QUESTION
        end_index = start_index + OMR_OPTIONS_PER_QUESTION
        if end_index > len(questionCnts):
            print(f"⚠️ Warning: Not enough contours remaining for Q{current_q_num}. Stopping OMR parse early.")
            break  # Stop if we run out of contours prematurely
        # Assume the next N contours belong to this question based on sorting
        current_q_contours = questionCnts[start_index:end_index]
        # Sanity check: Ensure these contours are reasonably close horizontally (same row)
        y_coords = [cv2.boundingRect(c)[1] for c in current_q_contours]
        if max(y_coords) - min(y_coords) > (warped_h * 0.05):  # Heuristic vertical spread check
            print(f"⚠️ Warning: Bubbles for Q{current_q_num} seem vertically misaligned. Skipping.")
            user_answers[current_q_num] = None
            continue
        # Re-sort this group left-to-right just to be sure
        current_q_contours = sorted(current_q_contours, key=lambda c: cv2.boundingRect(c)[0])
        marked_option = None
        max_filled_ratio = -1
        best_option_index = -1
        multiple_marks = False
        for j, c in enumerate(current_q_contours):
            processed_bubble_count +=1
            # Count thresholded (white) pixels inside this bubble only.
            mask = np.zeros(thresh.shape, dtype="uint8")
            cv2.drawContours(mask, [c], -1, 255, -1)
            mask = cv2.bitwise_and(thresh, thresh, mask=mask)
            total_pixels = cv2.countNonZero(mask)
            (x, y, w, h) = cv2.boundingRect(c)
            bubble_area = max(1, w * h)  # Use bounding box area as approx if contourArea is weird
            filled_ratio = total_pixels / float(bubble_area)
            # --- Mark Detection Logic ---
            if filled_ratio > OMR_MARK_THRESHOLD_RATIO:
                if best_option_index == -1:  # First potential mark found
                    max_filled_ratio = filled_ratio
                    best_option_index = j
                else:
                    # Found another mark. Check if it's significantly darker or just noise
                    if filled_ratio > max_filled_ratio * 1.2:  # If this one is much darker, prefer it
                        max_filled_ratio = filled_ratio
                        best_option_index = j
                        multiple_marks = False  # Reset multiple mark flag if this one is clearly dominant
                    elif max_filled_ratio > filled_ratio * 1.2:  # Previous one was much darker, ignore this
                        pass
                    else:  # Both are similarly dark - likely multiple marks
                        multiple_marks = True
                        print(f"⚠️ Multiple marks detected or ambiguous fill for Q{current_q_num}")
                        cv2.drawContours(annotated_img, [c], -1, (0, 165, 255), 2)  # Orange for ambiguous
                        cv2.drawContours(annotated_img, [current_q_contours[best_option_index]], -1, (0, 165, 255), 2)
            # Draw blue box around all detected bubbles for debugging
            cv2.rectangle(annotated_img, (x, y), (x + w, y + h), (255, 0, 0), 1)
        if multiple_marks:
            user_answers[current_q_num] = None  # Treat multiple marks as unattempted/error
        elif best_option_index != -1:
            user_answers[current_q_num] = best_option_index + 1  # Options are 1-based
            # Draw green circle on the single chosen marked bubble
            cv2.drawContours(annotated_img, [current_q_contours[best_option_index]], -1, (0, 255, 0), 2)
        else:
            user_answers[current_q_num] = None  # No mark found above threshold
        # Add question number text (optional)
        (x, y, w, h) = cv2.boundingRect(current_q_contours[0])
        text_y = y - 5 if y > 15 else y + h + 15
        cv2.putText(annotated_img, str(current_q_num), (max(0, x - 10), text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 1)
    print(f" - Processed {processed_bubble_count} bubbles across {q_index + 1} questions.")
    # Fill missing questions if loop didn't reach 180 (likely due to contour shortage)
    for q_final in range(1, OMR_TOTAL_QUESTIONS + 1):
        if q_final not in user_answers:
            user_answers[q_final] = None
            print(f" - Q{q_final} was not processed (likely missing contours). Marked as unattempted.")
    print(f"βœ… OMR Sheet processing finished in {time.time() - start_time:.2f}s.")
    return user_answers, annotated_img
def calculate_neet_score(correct_answers, user_answers, annotated_img=None, total_questions=180):
    """Score an attempt using the NEET pattern (+4 correct, -1 incorrect, 0 otherwise).

    Args:
        correct_answers: dict mapping question number -> correct option (1-4),
            or 0 when the answer key itself was ambiguous/unreadable.
        user_answers: dict mapping question number -> marked option (1-4), or
            None for no mark / multiple marks.
        annotated_img: optional OpenCV image; passed through unchanged.
        total_questions: number of questions to score (default 180, the
            standard NEET paper length; generalized from the old hard-coded
            OMR_TOTAL_QUESTIONS constant).

    Returns:
        (total_score, correct_count, incorrect_count, unattempted_count,
         results_details, annotated_img) where results_details maps each
        question number to {"user", "correct", "status"}.
    """
    if not correct_answers or not user_answers:
        # Nothing to score - return zeroed results rather than raising.
        return 0, 0, 0, 0, {}, annotated_img
    total_score = 0
    correct_count = 0
    incorrect_count = 0
    unattempted_count = 0
    results_details = {}  # Per-question breakdown for the UI/report
    for i in range(1, total_questions + 1):
        correct_opt = correct_answers.get(i)  # Expected to be 0, 1, 2, 3, or 4
        user_opt = user_answers.get(i)  # Can be 1, 2, 3, 4, or None
        if correct_opt == 0 or correct_opt is None:
            # Answer key was ambiguous/missing for this question - no penalty.
            status = "Key Error"
            unattempted_count += 1
        elif user_opt is None:
            status = "Unattempted/Error"  # Includes multiple marks or no mark found
            unattempted_count += 1
        elif user_opt == correct_opt:
            status = "Correct"
            total_score += 4
            correct_count += 1
        else:
            status = "Incorrect"
            total_score -= 1
            incorrect_count += 1
        results_details[i] = {"user": user_opt if user_opt else "-",
                              "correct": correct_opt if correct_opt else "Err",
                              "status": status}
    print(f"Score calculated: Total={total_score}, Correct={correct_count}, Incorrect={incorrect_count}, Unattempted={unattempted_count}")
    return total_score, correct_count, incorrect_count, unattempted_count, results_details, annotated_img
# --- Main Gradio Processing Function ---
def process_omr_and_key(answer_key_img_pil, omr_sheet_img_pil):
    """Top-level Gradio handler: extract key, read OMR sheet, score, format.

    Args:
        answer_key_img_pil: PIL image of the printed answer key.
        omr_sheet_img_pil: PIL image of the filled OMR sheet.

    Returns:
        (markdown_summary, annotated_pil_image_or_None). Errors are returned
        as a markdown error string with a None image, never raised to Gradio.
    """
    if answer_key_img_pil is None or omr_sheet_img_pil is None:
        return "❌ Error: Please upload both the Answer Key image and the OMR Sheet image.", None
    logs = []

    def log_message(message):
        # Mirror messages to stdout (HF Space logs) and the UI debug log.
        print(message)  # Print to console/HF logs
        logs.append(message)
    try:
        # 1. Extract Answer Key using Gemini
        correct_answers = extract_answer_key_gemini(answer_key_img_pil)
        if not correct_answers:
            raise gr.Error("Failed to extract Answer Key.")
        log_message(f"πŸ”‘ Answer Key processed. Found results for {len(correct_answers)} questions.")
        # 2. Extract User Answers from OMR using OpenCV
        user_answers, annotated_omr_cv = extract_omr_answers_cv(omr_sheet_img_pil)
        if not user_answers:
            raise gr.Error("Failed to extract answers from OMR sheet.")
        log_message(f"πŸ“Š OMR Sheet processed. Found results for {len(user_answers)} questions.")
        # Convert annotated CV image back to PIL for display
        annotated_omr_pil = cv2_to_pil(annotated_omr_cv) if annotated_omr_cv is not None else None
        # 3. Calculate Score and Get Details
        score, correct, incorrect, unattempted, results_details, _ = calculate_neet_score(
            correct_answers, user_answers, annotated_omr_cv
        )
        log_message(f"πŸ’― Score calculated: {score}")
        # 4. Format Results (markdown shown in the Gradio output panel)
        results_summary_md = f"""
## NEET OMR Analysis Results
**Total Score:** {score} / {OMR_TOTAL_QUESTIONS * 4}
---
**Breakdown:**
* βœ… **Correct Answers:** {correct} (+{correct * 4} marks)
* ❌ **Incorrect Answers:** {incorrect} ({incorrect * -1} marks)
* βšͺ **Unattempted / Errors:** {unattempted} (+0 marks)
---
**Debug Log:**
```
{chr(10).join(logs)}
```
---
**Disclaimer:**
*This result is based on automated analysis (Gemini & OpenCV) and may contain inaccuracies.*
*Verify results, especially if warnings appeared in logs.*
"""
        return results_summary_md, annotated_omr_pil  # Return summary and annotated image
    except gr.Error as ge:  # Catch Gradio specific errors for user feedback
        print(f"❌ Gradio Error: {ge}")
        return f"❌ Error: {ge}", None
    except Exception as e:
        print(f"❌ An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()  # Print full traceback to logs for debugging
        return f"❌ An unexpected error occurred during processing: {e}. Check logs for details.", None
# --- Gradio Interface ---
# Help text shown above the inputs in the Space UI.
description = f"""
## Advanced NEET OMR Sheet Checker (Gemini + OpenCV) [Beta]
Upload an image of the **NEET Answer Key** (with numerical options 1-4) and an image of your filled **NEET OMR Sheet**.
**How it works:**
1. **Answer Key:** Uses Google **Gemini 1.5 Flash** to analyze the image and extract the correct options (1-4) for all {OMR_TOTAL_QUESTIONS} questions. Requires `GOOGLE_API_KEY` in Space Secrets.
2. **OMR Sheet:** Uses **OpenCV** for image processing (perspective correction, thresholding, contour detection/filtering, sorting, mark detection).
3. **Scoring:** Calculates score based on NEET pattern (+4, -1, 0).
**IMPORTANT NOTES:**
* **API Key:** Add your `GOOGLE_API_KEY` to Space **Secrets**.
* **Accuracy:** Results depend heavily on image quality and clarity. OMR reading is complex. Check logs for warnings.
* **Parameters:** OpenCV parameters in the code might need tuning for different sheets/scans.
* **Output:** Score summary and annotated OMR (green circles on detected marks).
"""
# Load example images if they exist in the Space repository
example_key_path = "answer_key_example.jpg"  # Replace with your actual example filename
example_omr_path = "omr_sheet_example.png"  # Replace with your actual example filename
examples = []
if os.path.exists(example_key_path) and os.path.exists(example_omr_path):
    examples.append([example_key_path, example_omr_path])
iface = gr.Interface(
    fn=process_omr_and_key,
    inputs=[
        gr.Image(type="pil", label="1. Upload Answer Key Image (Numerical Options 1-4)"),
        gr.Image(type="pil", label="2. Upload OMR Sheet Image")
    ],
    outputs=[
        gr.Markdown(label="Results Summary & Log"),
        gr.Image(type="pil", label="Annotated OMR Sheet (Detected Marks in Green)")
    ],
    title="Advanced NEET OMR Checker (Gemini + OpenCV)",
    description=description,
    # allow_flagging="never", # Removed deprecated parameter
    examples=examples,
    cache_examples=False  # Disable caching if processing is complex/stateful
)
# Launch the interface for Hugging Face Spaces
if __name__ == "__main__":
    iface.launch(ssr_mode=False)  # Disable SSR for potentially better stability