| | |
| | """Flipkart Frontend.ipynb |
| | |
| | Automatically generated by Colab. |
| | |
| | Original file is located at |
| | https://colab.research.google.com/github/Abhinav-gh/404NotFound/blob/main/Flipkart%20Frontend.ipynb |
| | |
| | # 1. Install Gradio and Required Libraries |
| | ### Start by installing Gradio if it's not already installed. |
| | """ |
| |
|
| | """# 2. Import Libraries |
| | ### Getting all the necessary Libraries |
| | """ |
| |
|
| | import gradio as gr |
| | import random |
| | import numpy as np |
| | from PIL import Image |
| | import cv2 |
| | import time |
| | from ultralytics import YOLO |
| | import pandas as pd |
| | from collections import defaultdict, deque |
| | import torch |
| | from torchvision import transforms, models, datasets, transforms |
| | from torch.utils.data import DataLoader |
| | import torch.nn as nn |
| | import matplotlib.pyplot as plt |
| | import google.generativeai as genai |
| | from datetime import datetime |
| | from paddleocr import PaddleOCR |
| | import os |
| | import re |
| |
|
"""# Configuration

### Gemini API key (read from the GEMINI_API environment variable), used by the OCR post-processing
"""
| |
|
| | |
# Gemini API key taken from the GEMINI_API environment variable (None if unset).
GOOGLE_API_KEY = os.environ.get("GEMINI_API")
| | |
| |
|
| |
|
| | |
| | |
| |
|
| | """# 4. Brand Recognition Backend |
| | |
| | ### Model for Grocery Detection |
| | """ |
| |
|
| | """### Image uploading for Grocery detection""" |
| |
|
def detect_grocery_items(image):
    """Detect grocery items in one image with the ``kitkat_s.pt`` YOLO model.

    Args:
        image: RGB image as delivered by ``gr.Image`` (NumPy array or array-like).

    Returns:
        tuple: ``(image_rgb, summary_table, status_message)``.
        ``summary_table`` holds one ``[class_name, count, "avg_conf"]`` row per
        class, counting only detections with confidence >= 0.4. When nothing
        passes the threshold, the input image is returned unchanged (still RGB)
        with an empty table and an explanatory message.
    """
    model = YOLO('kitkat_s.pt')
    rgb_image = np.asarray(image)
    # Ultralytics expects raw NumPy input in OpenCV (BGR) channel order.
    bgr_image = rgb_image[:, :, ::-1]
    results = model(bgr_image)
    annotated_image = results[0].plot()

    class_ids = results[0].boxes.cls.cpu().numpy()
    confidences = results[0].boxes.conf.cpu().numpy()

    threshold = 0.4
    # class name -> confidences of its detections above the threshold
    class_confidences = defaultdict(list)
    for class_id, confidence in zip(class_ids, confidences):
        if confidence >= threshold:
            class_confidences[model.names[int(class_id)]].append(confidence)

    if not class_confidences:
        # Return the RGB input, not the BGR copy, so the UI shows correct colours.
        return rgb_image, [], "The model failed to recognize items or the image may contain untrained objects."

    summary_table = [
        [class_name, len(confs), f"{np.mean(confs):.2f}"]
        for class_name, confs in class_confidences.items()
    ]

    # plot() returns a BGR array; flip back to RGB for display.
    annotated_image_rgb = annotated_image[:, :, ::-1]
    return annotated_image_rgb, summary_table, "Object Recognised Successfully 🥳 "
| |
|
"""### Detect grocery brands in a video"""
| |
|
def iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Returns 0.0 when the union area is zero (e.g. two degenerate boxes)
    instead of dividing by zero.
    """
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2], box2[2])
    inter_y2 = min(box1[3], box2[3])

    intersection = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union = float(area1 + area2 - intersection)
    if union <= 0:
        return 0.0
    return intersection / union
| |
|
def smooth_box(box_history):
    """Average a history of boxes coordinate-wise.

    Returns ``None`` for an empty history, otherwise a NumPy array holding the
    per-coordinate mean.
    """
    return np.mean(box_history, axis=0) if box_history else None
| |
|
def process_video(input_path, output_path):
    """Annotate a video with tracked grocery detections and estimate item counts.

    Every 5th frame goes through the ``kitkat_s.pt`` YOLO model. Detections
    are attached to an existing track when their IoU with its smoothed box
    exceeds 0.5; otherwise a new track is started. Each track keeps its last
    10 boxes. A track is drawn on every frame and dropped once it has not been
    matched for more than ``2 * fps`` frames.

    Args:
        input_path: Path of the source video.
        output_path: Path the annotated ``mp4v`` video is written to.

    Returns:
        dict: ``{brand: estimated_quantity}`` for every brand detected in more
        than 10% of the total number of frames.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    model = YOLO('kitkat_s.pt')
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video file: {input_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 FPS; fall back to 30 so the writer is valid and
    # the stale-track timeout below does not drop every track immediately.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    detected_items = {}
    # Monotonic id counter: len(detected_items) would reuse ids after deletions
    # and overwrite a live track.
    next_item_id = 0
    frame_count = 0
    # brand -> {frame index: number of detections of that brand in the frame}
    detections_history = defaultdict(lambda: defaultdict(int))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            if frame_count % 5 == 0:
                results = model(frame)

                current_frame_detections = []
                for r in results:
                    for box in r.boxes:
                        x1, y1, x2, y2 = box.xyxy[0].tolist()
                        conf = box.conf.item()
                        brand = model.names[int(box.cls.item())]
                        current_frame_detections.append((brand, [x1, y1, x2, y2], conf))

                for brand, box, conf in current_frame_detections:
                    for item_info in detected_items.values():
                        if iou(box, item_info['smoothed_box']) > 0.5:
                            item_info['frames_detected'] += 1
                            item_info['total_conf'] += conf
                            # deque(maxlen=10) discards the oldest box automatically.
                            item_info['box_history'].append(box)
                            item_info['smoothed_box'] = smooth_box(item_info['box_history'])
                            item_info['last_seen'] = frame_count
                            break
                    else:
                        # No existing track overlaps enough: start a new one.
                        detected_items[next_item_id] = {
                            'brand': brand,
                            'box_history': deque([box], maxlen=10),
                            'smoothed_box': box,
                            'frames_detected': 1,
                            'total_conf': conf,
                            'last_seen': frame_count,
                        }
                        next_item_id += 1

                    detections_history[brand][frame_count] += 1

            # Drop stale tracks and draw the rest on every frame, so boxes
            # persist between analysed frames.
            for item_id, item_info in list(detected_items.items()):
                if frame_count - item_info['last_seen'] > fps * 2:
                    del detected_items[item_id]
                    continue
                if item_info['smoothed_box'] is None:
                    continue

                # Ease the drawn box 30% toward the most recent detection.
                alpha = 0.3
                current_box = item_info['smoothed_box']
                target_box = item_info['box_history'][-1] if item_info['box_history'] else current_box
                interpolated_box = [
                    current_box[i] * (1 - alpha) + target_box[i] * alpha
                    for i in range(4)
                ]
                item_info['smoothed_box'] = interpolated_box

                x1, y1, x2, y2 = map(int, interpolated_box)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, f"{item_info['brand']}",
                            (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

            out.write(frame)
    finally:
        # Always release the capture and writer, even if inference fails.
        cap.release()
        out.release()

    total_frames = frame_count
    confirmed_items = {}
    for brand, frame_counts in detections_history.items():
        detection_frames = len(frame_counts)
        # NOTE(review): only every 5th frame is analysed, so "10% of all frames"
        # means roughly half of the analysed frames — confirm this is intended.
        if detection_frames > total_frames * 0.1:
            avg_count = sum(frame_counts.values()) / detection_frames
            confirmed_items[brand] = round(avg_count)

    return confirmed_items
| |
|
def annotate_video(input_video):
    """Gradio handler: annotate *input_video* and summarise the detected items.

    Returns:
        tuple: (path of the annotated video, list of ``(brand, quantity)``
        tuples, status message).
    """
    annotated_path = 'annotated_output.mp4'
    counts = process_video(input_video, annotated_path)
    return annotated_path, list(counts.items()), "Video processed successfully!"
| |
|
"""# 5. OCR Backend

### Combined PaddleOCR + Gemini model.

The helper cells below must be run before trying any of the OCR models.
"""
| |
|
def new_draw_bounding_boxes(image):
    """Run OCR on *image*, plot the detected text boxes, and return the text.

    NOTE: a second ``new_draw_bounding_boxes`` is defined later in this module
    and replaces this one at import time.

    Args:
        image: File path (str) or PIL Image.

    Returns:
        list[str]: Recognised text lines in OCR order. Returns an empty list
        when no text is found, and also on any error (the error is printed,
        not raised).
    """
    fig = None
    try:
        if isinstance(image, str):
            # Context manager closes the file once the pixels are copied.
            with Image.open(image) as img:
                np_img = np.array(img)
            print("[DEBUG] Loaded image from file path.")
        elif isinstance(image, Image.Image):
            np_img = np.array(image)
            print("[DEBUG] Converted PIL Image to NumPy array.")
        else:
            raise ValueError("Input must be a file path or a PIL Image object.")

        ocr_result = ocr.ocr(np_img, cls=True)
        print("[DEBUG] OCR Result:\n", ocr_result)
        # ocr_result[0] can be empty/None when no text is found.
        detections = ocr_result[0] if ocr_result and ocr_result[0] else []

        fig = plt.figure(figsize=(10, 10))
        # Plot the array: plt.imshow cannot take a file path.
        plt.imshow(np_img)
        ax = plt.gca()
        all_text_data = []

        for idx, line in enumerate(detections):
            box = line[0]
            text = line[1][0]
            print(f"[DEBUG] Box {idx + 1}: {text}")
            all_text_data.append(text)

            polygon = plt.Polygon(box, fill=None, edgecolor='red', linewidth=2)
            ax.add_patch(polygon)

            # Label each box just above its first corner point.
            x, y = box[0][0], box[0][1]
            ax.text(x, y - 5, f"{idx + 1}: {text}", color='blue', fontsize=12, ha='left')

        plt.axis('off')
        plt.title("Detected Text with Bounding Boxes", fontsize=16)
        plt.show()

        return all_text_data

    except Exception as e:
        print(f"[ERROR] Error in new_draw_bounding_boxes: {e}")
        return []
    finally:
        # Free the figure so repeated requests do not accumulate open figures.
        if fig is not None:
            plt.close(fig)
| |
|
# Configure the Gemini client once at import time with GOOGLE_API_KEY.
genai.configure(api_key=GOOGLE_API_KEY)


def gemini_context_correction(text):
    """Ask Gemini to clean up noisy OCR *text* and extract dates and MRP.

    Args:
        text: Raw (possibly multi-line, noisy) OCR text.

    Returns:
        str: Gemini's reply text. The prompt requests the form
        ``Manufacturing Date: ... Expiration Date: ... MRP: ...``.
    """
    model = genai.GenerativeModel('models/gemini-1.5-flash')

    # Separate instructions with newlines; adjacent literals are concatenated
    # verbatim and would otherwise run sentences together.
    prompt = (
        "Identify and extract manufacturing, expiration dates, and MRP from the following text. "
        "The dates may be written in dd/mm/yyyy format or as <Month_name> <Year> or <day> <Month_Name> <Year>. "
        "The text may contain noise or unclear information. If only one date is provided, assume it is the Expiration Date. "
        "Additionally, extract the MRP (e.g., 'MRP: ₹99.00', 'Rs. 99/-'). "
        "Format the output as:\n"
        "Manufacturing Date: <MFG Date> Expiration Date: <EXP Date> MRP: <MRP Value>\n"
        "Do **not** generate example text or assumptions.\n"
        f"Here is the text: {text}"
    )
    response = model.generate_content(prompt)

    return response.text
| |
|
| |
|
| |
|
def validate_dates_with_gemini(mfg_date, exp_date):
    """Ask Gemini to validate (and if needed swap or correct) two dates.

    Args:
        mfg_date: Manufacturing date string, or '-1' when unknown.
        exp_date: Expiration date string, or '-1' when unknown.

    Returns:
        str: Gemini's reply, stripped of surrounding whitespace. Returns
        ``"Invalid response from Gemini API."`` if the reply has no parts.
    """
    model = genai.GenerativeModel('models/gemini-1.5-flash')
    # Call generate_content (the original assigned the prompt to it instead).
    response = model.generate_content(
        f"Input Manufacturing Date: {mfg_date}, Expiration Date: {exp_date}. "
        "If either date is '-1', leave it as is. "
        "1. If the expiration date is earlier than the manufacturing date, swap them. "
        "2. If both dates are logically incorrect, suggest new valid dates based on typical timeframes. "
        "Always respond ONLY in the format:\n"
        "Manufacturing Date: <MFG Date>, Expiration Date: <EXP Date>"
    )

    if response.parts:
        return response.parts[0].text.strip()

    return "Invalid response from Gemini API."
| |
|
| |
|
def extract_and_validate_with_gemini(refined_text):
    """Extract dates and MRP from *refined_text* via Gemini and sanity-check the dates.

    Gemini is asked for ``dd/mm/yyyy`` (or ``mm/yyyy``) dates and an
    ``INR <amount>`` MRP, using -1 for missing fields. When both dates are
    present and parseable, they are swapped if the expiration date is earlier
    than the manufacturing date. The result is then rebuilt with one field per
    line (``Manufacturing Date: …``, ``Expiration Date: …``, ``MRP: …``). It is
    prefixed with ``"Corrected Dates: "`` and a newline when a swap happened.
    In every other case Gemini's reply is returned unchanged.

    Returns:
        str: The (possibly corrected) result. Returns
        ``"Invalid response from Gemini API."`` when Gemini returns no parts.
    """
    model = genai.GenerativeModel('models/gemini-1.5-flash')

    response = model.generate_content(
        f"The extracted text is:\n'{refined_text}'\n\n"
        "1. Extract the 'Manufacturing Date', 'Expiration Date', and 'MRP' from the above text. "
        "Ignore unrelated data.\n"
        "2. If a date or MRP is missing or invalid, return -1 for that field.\n"
        "3. If the 'Expiration Date' is earlier than the 'Manufacturing Date', swap them.\n"
        "4. Ensure both dates are in 'dd/mm/yyyy' format. If the original dates are not in this format, convert them. "
        "However, if the dates are in 'mm/yyyy' format (without a day), leave them as is and return in 'mm/yyyy' format. "
        "If the dates do not have a day, return them in 'mm/yyyy' format.\n"
        "5. MRP should be returned in the format 'INR <amount>'. If not found or invalid, return 'INR -1'.\n"
        "Respond ONLY in this exact format:\n"
        "Manufacturing Date: <MFG Date>\n"
        "Expiration Date: <EXP Date>\n"
        "MRP: <MRP>"
    )

    if not (hasattr(response, 'parts') and response.parts):
        print("[ERROR] Invalid response from Gemini API.")
        return "Invalid response from Gemini API."

    final_dates = response.parts[0].text.strip()
    print(f"[DEBUG] Gemini Response: {final_dates}")

    mfg_date_str, exp_date_str, mrp_str = parse_gemini_response(final_dates)
    if mfg_date_str == "-1" or exp_date_str == "-1":
        return final_dates

    try:
        mfg_date = parse_date(mfg_date_str)
        exp_date = parse_date(exp_date_str)
    except ValueError:
        # Gemini did not follow the requested date format; show its reply as-is.
        print("[DEBUG] Could not parse dates; returning Gemini response unchanged.")
        return final_dates

    swapping_statement = ""
    if exp_date < mfg_date:
        print("[DEBUG] Swapping dates.")
        # Swap the original strings so a real first-of-month date keeps its day.
        mfg_date_str, exp_date_str = exp_date_str, mfg_date_str
        swapping_statement = "Corrected Dates: \n"

    # One field per line, matching the format requested from Gemini, so that
    # parsers that split on newlines see each field on its own line.
    return (
        f"{swapping_statement}"
        f"Manufacturing Date: {mfg_date_str}\n"
        f"Expiration Date: {exp_date_str}\n"
        f"MRP: {mrp_str}"
    )


def parse_gemini_response(response_text):
    """Split a Gemini reply into ``(mfg_date, exp_date, mrp)`` strings.

    Fields may be separated by newlines (the format requested in
    extract_and_validate_with_gemini) or by commas. Markdown bold markers
    around labels or values are ignored.

    Returns:
        tuple[str, str, str]: The three field values. Returns
        ``("-1", "-1", "INR -1")`` if either date label is missing. A missing
        MRP alone defaults to ``"INR -1"``.
    """
    def _field(label, value_pattern):
        match = re.search(rf"{label}[ \t*]*:[ \t*]*({value_pattern})", response_text)
        if not match:
            return None
        return match.group(1).strip(" \t*") or None

    mfg_date_str = _field("Manufacturing Date", r"[^,\n]+")
    exp_date_str = _field("Expiration Date", r"[^,\n]+")
    # MRP may itself contain commas (e.g. "INR 1,299"), so read to end of line.
    mrp_str = _field("MRP", r"[^\n]+")

    if mfg_date_str is None or exp_date_str is None:
        print("[ERROR] Failed to parse Gemini response.")
        return "-1", "-1", "INR -1"
    return mfg_date_str, exp_date_str, (mrp_str if mrp_str is not None else "INR -1")


def parse_date(date_str):
    """Parse ``dd/mm/yyyy`` or ``mm/yyyy`` into a datetime.

    ``mm/yyyy`` inputs yield the first day of that month.

    Raises:
        ValueError: If *date_str* matches neither format.
    """
    date_str = date_str.strip()
    # A single '/' means month/year; anything else is tried as day/month/year.
    fmt = "%m/%Y" if date_str.count('/') == 1 else "%d/%m/%Y"
    return datetime.strptime(date_str, fmt)


def format_date(date):
    """Format a datetime as ``dd/mm/yyyy``, or as ``mm/yyyy`` when ``date.day == 1``.

    NOTE: a genuine first-of-month date (e.g. 01/05/2024) is therefore
    rendered without its day. This mirrors parse_date, which maps ``mm/yyyy``
    to day 1.
    """
    if date.day == 1:
        return date.strftime('%m/%Y')
    return date.strftime('%d/%m/%Y')
| |
|
| |
|
def extract_date(refined_text, date_type):
    """Return the value that follows ``date_type:`` in comma-separated *refined_text*.

    Returns '-1' when *date_type* is absent or its segment has no ':' value.
    """
    if date_type not in refined_text:
        return '-1'

    segment = next(
        (chunk for chunk in refined_text.split(',') if date_type in chunk),
        None,
    )
    if segment is None:
        return '-1'

    pieces = segment.split(':')
    return pieces[1].strip() if len(pieces) > 1 else '-1'
| |
|
def extract_details_from_validated_output(validated_output):
    """Build the ``[[label, value], ...]`` rows shown in the OCR details table.

    Each field is searched independently, so a missing or ``-1`` field does
    not hide the others. Fields may be separated by newlines or commas.

    Returns:
        list: ``[["Manufacturing Date", mfg], ["Expiration Date", exp], ["MRP", mrp]]``.
        A missing date is ``"Not Found"`` and a missing MRP is ``"INR -1"``.
    """
    text = validated_output or ""
    print("[DEBUG] Validated Output:", validated_output)

    mfg_match = re.search(r"Manufacturing Date:\s*([\d/]+)", text)
    exp_match = re.search(r"Expiration Date:\s*([\d/]+)", text)
    # Accept thousands separators and decimals, e.g. "INR 1,299.00".
    mrp_match = re.search(r"MRP:\s*INR\s*([\d.,]*\d)", text)

    mfg_date = mfg_match.group(1) if mfg_match else "Not Found"
    exp_date = exp_match.group(1) if exp_match else "Not Found"
    mrp = f"INR {mrp_match.group(1)}" if mrp_match else "INR -1"

    if mfg_match or exp_match or mrp_match:
        print("[DEBUG] Extracted Manufacturing Date:", mfg_date)
        print("[DEBUG] Extracted Expiration Date:", exp_date)
        print("[DEBUG] Extracted MRP:", mrp)
    else:
        print("[ERROR] No match found for the specified pattern.")

    return [
        ["Manufacturing Date", mfg_date],
        ["Expiration Date", exp_date],
        ["MRP", mrp]
    ]
| |
|
"""### **Model 3**
Uses a YOLOv8 x-large model trained for about 75 epochs,
with Gradio as the user interface
(if the model fails, we fall back to the approach from Model 1).

"""
| |
|
def new_draw_bounding_boxes(image):
    """Run OCR on the whole *image*, plot the detected text boxes, and return the text.

    Args:
        image: File path (str) or PIL Image.

    Returns:
        list[str]: Recognised text lines in OCR order; empty when no text is found.

    Raises:
        ValueError: If *image* is neither a path nor a PIL Image.
    """
    if isinstance(image, str):
        # Context manager closes the file once the pixels are copied.
        with Image.open(image) as img:
            np_img = np.array(img)
    elif isinstance(image, Image.Image):
        np_img = np.array(image)
    else:
        raise ValueError("Input must be a file path or a PIL Image object.")

    ocr_result = ocr.ocr(np_img, cls=True)
    # ocr_result[0] can be empty/None when no text is found.
    detections = ocr_result[0] if ocr_result and ocr_result[0] else []

    fig = plt.figure(figsize=(10, 10))
    try:
        # Plot the array: plt.imshow cannot take a file path.
        plt.imshow(np_img)
        ax = plt.gca()
        all_text_data = []

        for idx, line in enumerate(detections):
            box = line[0]
            text = line[1][0]
            print(f"[DEBUG] Box {idx + 1}: {text}")
            all_text_data.append(text)

            polygon = plt.Polygon(box, fill=None, edgecolor='red', linewidth=2)
            ax.add_patch(polygon)

            # Label each box just above its first corner point.
            x, y = box[0][0], box[0][1]
            ax.text(x, y - 5, f"{idx + 1}: {text}", color='blue', fontsize=12, ha='left')

        plt.axis('off')
        plt.title("Detected Text with Bounding Boxes", fontsize=16)
        plt.show()
    finally:
        # Free the figure so repeated requests do not accumulate open figures.
        plt.close(fig)

    return all_text_data
| |
|
| |
|
| | |
# Shared PaddleOCR engine (English, angle classification on) used by the OCR helpers.
ocr = PaddleOCR(use_angle_cls=True, lang='en')


def detect_and_ocr(image):
    """Detect text regions with YOLO (``best.pt``), OCR each one, and refine with Gemini.

    Args:
        image: RGB image (PIL Image from the Gradio upload).

    Returns:
        tuple: (RGB image with green region boxes, raw OCR text joined by
        newlines, Gemini-refined text, validated output string).
    """
    model = YOLO('best.pt')

    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Crop from an untouched copy so the display boxes drawn on ``image`` do
    # not leak into the OCR input of this or any overlapping region.
    clean_image = image.copy()

    results = model(image)
    boxes = results[0].boxes.xyxy.cpu().numpy()

    extracted_texts = []
    for (x1, y1, x2, y2) in boxes:
        left, top, right, bottom = int(x1), int(y1), int(x2), int(y2)
        cv2.rectangle(image, (left, top), (right, bottom), (0, 255, 0), 2)

        region = clean_image[top:bottom, left:right]
        # Degenerate boxes produce an empty crop; skip OCR for those.
        ocr_result = ocr.ocr(region, cls=True) if region.size else None

        if ocr_result and isinstance(ocr_result, list) and ocr_result[0]:
            for idx, line in enumerate(ocr_result[0]):
                text = line[1][0]
                print(f"[DEBUG] Box {idx + 1}: {text}")
                extracted_texts.append(text)
        else:
            print(f"[DEBUG] No OCR result for region: ({x1}, {y1}, {x2}, {y2}) or OCR returned None")
            extracted_texts.append("No OCR result found")

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    result_text = "\n".join(str(text) for text in extracted_texts)

    refined_text = gemini_context_correction(result_text)
    print("[DEBUG] Gemini Refined Text:\n", refined_text)

    validated_output = extract_and_validate_with_gemini(refined_text)
    print("[DEBUG] Validated Output from Gemini:\n", validated_output)

    return image_rgb, result_text, refined_text, validated_output
| |
|
def further_processing(image, previous_result_text):
    """Run full-image OCR and ask Gemini again using the combined text.

    Args:
        image: The uploaded image (path or PIL Image).
        previous_result_text: OCR text from the region-based pass (may be empty).

    Returns:
        str: Gemini's refined output for the previous text plus the new lines.
    """
    bounding_boxes_list = new_draw_bounding_boxes(image)
    print("[DEBUG] ", bounding_boxes_list, type(bounding_boxes_list))

    # One entry per line, so the last previous line is not glued to the first
    # newly detected line.
    lines = [previous_result_text] if previous_result_text else []
    lines.extend(bounding_boxes_list)
    combined_text = "".join(f"{line}\n" for line in lines)
    print("[DEBUG] combined text", combined_text)

    refined_output = gemini_context_correction(combined_text)
    print("[DEBUG] Gemini Refined Output:\n", refined_output)

    return refined_output
| |
|
def handle_processing(validated_output):
    """Return a Gradio update that shows or hides the "Comprehensive OCR" button.

    The button is shown only when the validated output has no usable
    manufacturing date, expiration date, or MRP. It is hidden otherwise,
    including when the text cannot be split into its labelled fields.
    """
    try:
        mfg_raw = validated_output.split("Manufacturing Date: ")[1].split("\n")[0].strip()
        exp_raw = validated_output.split("Expiration Date: ")[1].split("\n")[0].strip()
        mrp_raw = validated_output.split("MRP: ")[1].strip()
    except IndexError as e:
        print(f"[ERROR] Failed to parse validated output: {e}")
        return gr.update(visible=False)

    # A date counts only if it contains a '/' (the "-1" sentinel never does).
    mfg_date = mfg_raw if "/" in mfg_raw else -1
    exp_date = exp_raw if "/" in exp_raw else -1
    # The MRP counts only when it is "INR <amount>" and not the "INR -1" sentinel.
    if mrp_raw != "INR -1" and mrp_raw.startswith("INR "):
        mrp = mrp_raw.split("INR ")[1].strip()
    else:
        mrp = -1

    print("Further processing: ", mfg_date, exp_date, mrp)

    nothing_found = mfg_date == -1 and exp_date == -1 and mrp == -1
    if nothing_found:
        print("[DEBUG] Showing the 'Further Processing' button.")
    else:
        print("[DEBUG] Hiding the 'Further Processing' button.")
    return gr.update(visible=nothing_found)
| |
|
| |
|
| | """# 5. Freshness backend |
| | |
| | """ |
# Run feature extraction on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EfficientNet_FeatureExtractor(nn.Module):
    """Pretrained EfficientNet-B0 with its final child module (the classifier) removed.

    ``forward`` returns one flattened feature vector per input image.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_b0(pretrained=True)
        # Keep every child except the last one, i.e. drop the classification head.
        self.efficientnet = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        return torch.flatten(self.efficientnet(x), start_dim=1)
| | |
| |
|
# Un-normalized preprocessing (Resize 256 -> CenterCrop 224 -> tensor), used
# only to measure the dataset's per-channel statistics below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# NOTE(review): runs at import time and reads every image under
# Datasets/Bananas/Dataset; presumably the reference set of bananas — confirm.
dataset = datasets.ImageFolder(root='Datasets/Bananas/Dataset', transform=transform)

# Order is irrelevant for the statistics, so no shuffling.
loader = DataLoader(dataset, batch_size=32, shuffle=False)

# Running sums of per-image channel means / stds, and the image count.
mean = 0.0
std = 0.0
total_images = 0

for images, _ in loader:
    batch_samples = images.size(0)
    # Flatten the spatial dims: (N, C, H*W), so each channel's pixels form one row.
    images = images.view(batch_samples, images.size(1), -1)

    # Per-image channel mean/std, summed over the batch.
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)
    total_images += batch_samples

# Average over images. ``std`` is the mean of per-image stds, an approximation
# of the dataset-wide pixel std. An empty dataset would divide by zero here.
mean /= total_images
std /= total_images

print(f"Mean: {mean}")
print(f"Std: {std}")
| |
|
| |
|
| | |
| | |
# Same preprocessing as before plus normalization with the per-channel
# ``mean``/``std`` computed from the dataset above.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# The same image folder, now normalized; its features form the reference
# statistics used for the Mahalanobis distance.
test_dataset = datasets.ImageFolder(root='Datasets/Bananas/Dataset', transform=transform)
| |
|
| | |
def extract_features(test_dataset):
    """Embed every image in *test_dataset* with a new EfficientNet feature extractor.

    Returns:
        list: One NumPy feature array per batch of up to 50 images, in dataset order.
    """
    extractor = EfficientNet_FeatureExtractor().to(device)
    extractor.eval()

    batches = DataLoader(test_dataset, batch_size=50, shuffle=False)

    # Inference only: no autograd bookkeeping.
    with torch.no_grad():
        return [
            extractor(batch.to(device)).cpu().numpy()
            for batch, _labels in batches
        ]
| |
|
# Embed the normalized reference dataset; the result is a list of per-batch arrays.
all_features = extract_features(test_dataset)

for i, features in enumerate(all_features):
    print(f"Shape of batch {i}: {features.shape}")

# Stack the per-batch arrays into one (num_images, num_features) tensor on the CPU.
all_features_tensor = torch.cat([torch.tensor(batch) for batch in all_features], dim=0)

# Per-feature mean (moved to ``device`` for the distance computation) and
# per-feature variance (currently unused here).
feature_mean = all_features_tensor.mean(dim=0)
feature_mean = feature_mean.to(device)
feature_variance = all_features_tensor.var(dim=0)

print(f"Feature Mean Shape: {feature_mean.shape}")

# NOTE(review): rebuilds the same stacked tensor as above, this time on ``device``.
all_features_tensor = torch.cat([torch.tensor(f) for f in all_features], dim=0)
all_features_tensor = all_features_tensor.to(device)
feature_mean_temp = all_features_tensor.mean(dim=0)
centered_features = all_features_tensor - feature_mean_temp

# torch.cov treats each row as a variable, hence the transpose to
# (num_features, num_images). torch.cov centers internally, so the explicit
# centering above does not change the result.
covariance_matrix = torch.cov(centered_features.T)
covariance_matrix = covariance_matrix.to(device)

print(f"All Feature Tensor Shape: {all_features_tensor.shape}")
print(f"Covariance Matrix Shape: {covariance_matrix.shape}")
| |
|
| | |
| |
|
| |
|
def mahalanobis(x=None, feature_mean=None, feature_cov=None):
    """Compute the squared Mahalanobis distance of each row of *x* to the reference data.

    Args:
        x: Tensor of shape ``[batch_size, num_features]``. A single
            ``[num_features]`` vector is also accepted and treated as a batch of one.
        feature_mean: Tensor of shape ``[num_features]``, the mean of the
            reference feature vectors.
        feature_cov: Tensor of shape ``[num_features, num_features]``, the
            covariance matrix of the reference feature vectors.

    Returns:
        Tensor of shape ``[batch_size]`` holding ``(x - mu)^T Sigma^-1 (x - mu)``
        per row. No square root is taken.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)

    x_minus_mu = x - feature_mean
    inv_covmat = torch.inverse(feature_cov)
    left_term = torch.matmul(x_minus_mu, inv_covmat)

    # Row-wise dot product: equals diag(left_term @ x_minus_mu.T) without
    # building the full batch x batch matrix.
    return (left_term * x_minus_mu).sum(dim=1)
| |
|
| |
|
# Preprocessing applied by classify_banana at inference time.
# NOTE(review): identical to the normalized ``transform`` defined earlier from
# the same ``mean``/``std``; it simply rebinds the name.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
| |
|
def classify_banana_by_distance(distance):
    """Classify a banana's freshness from its (scaled) Mahalanobis distance.

    Ranges:
        * ``distance >= 9``: Completely Fresh.
        * ``-90 <= distance < 0``: Moderately Ripe.
        * anything else (including ``0 <= distance < 9`` and ``distance < -90``):
          Almost Rotten.

    Args:
        distance (float or single-element tensor): The banana's distance score.

    Returns:
        dict: The classification and its descriptive details.
    """
    # float() accepts plain numbers and single-element tensors alike.
    distance = float(distance)

    if distance >= 9:
        return {
            "Classification": "Completely Fresh",
            "Freshness Index": 10,
            "Color": "Mostly yellow, little to no brown spots",
            "Dark Spots": "0-10%",
            "Shelf Life": "5-7 days",
            "Ripeness Stage": "Just ripe",
            "Texture": "Firm and smooth"
        }
    if -90 <= distance < 0:
        return {
            "Classification": "Moderately Ripe",
            "Freshness Index": 6,
            "Color": "60% yellow, 40% dark spots",
            "Dark Spots": "40% dark spots",
            "Shelf Life": "2-3 days",
            "Ripeness Stage": "Moderately ripe",
            "Texture": "Some softness, still edible"
        }
    return {
        "Classification": "Almost Rotten",
        "Freshness Index": 2,
        "Color": "Mostly brown or black, very few yellow patches",
        "Dark Spots": "80-100% dark spots",
        "Shelf Life": "0-1 days",
        "Ripeness Stage": "Overripe",
        "Texture": "Very soft, mushy, may leak moisture"
    }
_banana_feature_extractor = None


def _get_banana_feature_extractor():
    """Build the EfficientNet feature extractor once and reuse it across requests."""
    global _banana_feature_extractor
    if _banana_feature_extractor is None:
        extractor = EfficientNet_FeatureExtractor().to(device)
        extractor.eval()
        _banana_feature_extractor = extractor
    return _banana_feature_extractor


def classify_banana(image):
    """Classify banana freshness from an uploaded image.

    Args:
        image: Image as a NumPy array (from ``gr.Image(type="numpy")``).

    Returns:
        dict: Classification details from classify_banana_by_distance.
    """
    model = _get_banana_feature_extractor()

    # Force 3 channels: the normalization uses per-channel statistics computed
    # on RGB images, so RGBA or grayscale uploads would otherwise fail.
    img = Image.fromarray(image).convert("RGB")
    img_transformed = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model(img_transformed)

    distance = mahalanobis(features, feature_mean, covariance_matrix)
    # NOTE(review): fixed scale presumably matched to the thresholds in
    # classify_banana_by_distance — keep them in sync.
    distance = distance / 1e8

    return classify_banana_by_distance(distance)
def detect_objects(image):
    """Run the ``Yash_Best.pt`` freshness detector on *image* and draw labelled boxes.

    Args:
        image: PIL image from the Gradio upload, or ``None`` (e.g. when the
            live input is cleared).

    Returns:
        RGB NumPy array with green boxes and ``"<class> <confidence>"``
        labels, or ``None`` when no image was provided.
    """
    if image is None:
        return None

    model = YOLO('Yash_Best.pt')
    result = model(image)[0]

    # orig_img is the BGR array Ultralytics ran inference on; draw on it directly.
    img = result.orig_img

    if result.boxes is not None:
        for box, conf, cls in zip(result.boxes.xyxy, result.boxes.conf, result.boxes.cls):
            x1, y1, x2, y2 = map(int, box[:4])
            label = f'{result.names[int(cls.item())]} {conf.item():.2f}'

            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, label, (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
| |
|
def detect_objects_video(video_file):
    """Run the ``Yash_Best.pt`` detector on every frame of a video and save the result.

    Args:
        video_file: Path of the uploaded video.

    Returns:
        str: Path of the annotated ``mp4v`` output video.

    Raises:
        Exception: If the video cannot be opened.
    """
    model = YOLO('Yash_Best.pt')
    cap = cv2.VideoCapture(video_file)

    if not cap.isOpened():
        raise Exception("Could not open video file.")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 FPS; fall back to 30 so the writer stays valid.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

    output_video_path = 'output_detected_video.mp4'
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            result = model(frame)[0]

            if result.boxes is not None:
                for box, conf, cls in zip(result.boxes.xyxy, result.boxes.conf, result.boxes.cls):
                    x1, y1, x2, y2 = map(int, box[:4])
                    label = f'{result.names[int(cls.item())]} {conf.item():.2f}'

                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(frame, label, (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)

            out.write(frame)
    finally:
        # Release the capture and writer even if inference raises.
        cap.release()
        out.release()

    return output_video_path
| |
|
| |
|
| | """# 5. Frontend Of Brand Recognition |
| | |
| | ## Layout for Image interface |
| | """ |
| |
|
def create_image_interface():
    """Build the Gradio Interface for grocery detection in a single image."""
    image_input = gr.Image(label="Upload Image", height=400, width=400)
    result_outputs = [
        gr.Image(label="Image with Bounding Boxes", height=400, width=400),
        gr.Dataframe(
            headers=["Item", "Quantity", "Avg Confidence"],
            label="Detected Items and Quantities",
            elem_id="summary_table",
        ),
        gr.Textbox(label="Status", elem_id="status_message"),
    ]
    page_css = ".gr-table { font-size: 16px; text-align: left; width: 50%; margin: auto; } #summary_table { margin-top: 20px; }"

    return gr.Interface(
        fn=detect_grocery_items,
        inputs=image_input,
        outputs=result_outputs,
        title="Grocery Item Detection in an Image",
        description="Upload an image for object detection. The model will return an annotated image, item quantities, and average confidence scores.",
        css=page_css,
    )
| |
|
| | """## Layout For Video Interface""" |
| |
|
def create_video_interface():
    """Build the Gradio Interface for grocery detection in an uploaded video."""
    video_input = gr.Video(label="Upload Video", height=400, width=400)
    result_outputs = [
        gr.Video(label="Annotated Video", height=400, width=400),
        gr.Dataframe(
            headers=["Item", "Quantity"],
            label="Detected Items and Quantities",
            elem_id="summary_table",
        ),
        gr.Textbox(label="Status", elem_id="status_message"),
    ]
    table_css = """
.gr-table { font-size: 16px; text-align: left; width: 50%; margin: auto; }
#summary_table { margin-top: 20px; }
"""

    return gr.Interface(
        fn=annotate_video,
        inputs=video_input,
        outputs=result_outputs,
        title="Grocery Item Detection in a Video",
        description="Upload a video for object detection. The model will return an annotated video with bounding boxes and item quantities. Low confidence values may indicate incorrect detection.",
        css=table_css,
    )
| |
|
def create_brand_recog_interface():
    """Combine the image and video grocery-detection interfaces under tabs."""
    tab_builders = (
        ("Image", create_image_interface),
        ("Video", create_video_interface),
    )
    with gr.Blocks() as demo:
        gr.Markdown("# Flipkart Grid Robotics Track - Brand Recognition Interface")

        with gr.Tabs():
            for tab_label, build in tab_builders:
                with gr.Tab(tab_label):
                    build()
    return demo


Brand_recog = create_brand_recog_interface()
| |
|
| | """# Frontend Of OCR""" |
| |
|
def create_ocr_interface():
    """Build the OCR Blocks UI: an upload/detection tab, a results tab, and the event wiring.

    Returns:
        gr.Blocks: The assembled OCR interface.
    """
    with gr.Blocks() as ocr_interface:
        gr.Markdown("# Flipkart Grid Robotics Track - OCR Interface")

        with gr.Tabs():
            # Tab 1: the uploaded image next to the YOLO-annotated result.
            with gr.TabItem("Upload & Detection"):
                with gr.Row():
                    input_image = gr.Image(type="pil", label="Upload Image", height=400, width=400)
                    output_image = gr.Image(label="Image with Bounding Boxes", height=400, width=400)

                btn = gr.Button("Analyze Image & Extract Text")

            # Tab 2: raw OCR text, Gemini's refinement, and the validated result.
            with gr.TabItem("OCR Results"):
                with gr.Row():
                    extracted_textbox = gr.Textbox(label="Extracted OCR Text", lines=5)
                with gr.Row():
                    refined_textbox = gr.Textbox(label="Refined Text from Gemini", lines=5)
                with gr.Row():
                    validated_textbox = gr.Textbox(label="Validated Output", lines=5)

                # Dates and MRP parsed out of the validated output.
                with gr.Row():
                    detail_table = gr.Dataframe(
                        headers=["Label", "Value"],
                        value=[["", ""], ["", ""], ["", ""]],
                        label="Manufacturing, Expiration Dates & MRP",
                        datatype=["str", "str"],
                        interactive=False,
                    )

                # Hidden initially; handle_processing decides when to show it.
                further_button = gr.Button("Comprehensive OCR", visible=False)

        # Main pipeline (detect + OCR + Gemini), then make the details table visible.
        btn.click(
            detect_and_ocr,
            inputs=[input_image],
            outputs=[output_image, extracted_textbox, refined_textbox, validated_textbox]
        ).then(
            lambda: gr.update(visible=True),
            outputs=[detail_table]
        )

        # Re-parse the details table whenever the validated output changes.
        validated_textbox.change(
            lambda validated_output: extract_details_from_validated_output(validated_output),
            inputs=[validated_textbox],
            outputs=[detail_table]
        )

        # Comprehensive OCR: full-image OCR combined with the earlier text, refined again.
        further_button.click(
            further_processing,
            inputs=[input_image, extracted_textbox],
            outputs=refined_textbox
        )

        # When the refined text changes, toggle the button based on the validated output.
        refined_textbox.change(
            handle_processing,
            inputs=[validated_textbox],
            outputs=[further_button]
        )

        # Second handler on the same button: hide the details table.
        further_button.click(
            lambda: gr.update(visible=False),
            outputs=[detail_table]
        )

    return ocr_interface


# Module-level OCR UI instance, combined into the tabbed app below.
ocr_interface = create_ocr_interface()
| |
|
| | """ 6. Front End of Fruit Index |
| | """ |
def create_banana_classifier_interface():
    """Build the Interface that grades banana freshness from one image."""
    interface_kwargs = dict(
        fn=classify_banana,
        inputs=gr.Image(type="numpy", label="Upload a Banana Image"),
        outputs=gr.JSON(label="Classification Result"),
        title="Banana Freshness Classifier",
        description="Upload an image of a banana to classify its freshness.",
        css="#component-0 { width: 300px; height: 300px; }",
    )
    return gr.Interface(**interface_kwargs)
| |
|
def image_freshness_interface():
    """Build the live Interface that detects fruit freshness in one image."""
    upload = gr.Image(type="pil", label="Upload an Image")
    annotated = gr.Image(type="pil", label="Detected Image")
    return gr.Interface(
        fn=detect_objects,
        inputs=upload,
        outputs=annotated,
        live=True,
        title="Image Freshness Detection",
        description="Upload an image of fruit to detect freshness.",
        css="#component-0 { width: 300px; height: 300px; }",
    )
| |
|
def video_freshness_interface():
    """Build the Interface that annotates fruit freshness in an uploaded video."""
    video_in = gr.Video(label="Upload a Video")
    video_out = [gr.Video(label="Processed Video")]
    return gr.Interface(
        fn=detect_objects_video,
        inputs=video_in,
        outputs=video_out,
        title="Video Freshness Detection",
        description="Upload a video of fruit to detect freshness.",
        css="#component-0 { width: 300px; height: 300px; }",
    )
| |
|
def create_fruit_interface():
    """Group the banana classifier and the fruit freshness detectors under tabs."""
    tab_builders = [
        ("Banana", create_banana_classifier_interface),
        ("Image Freshness", image_freshness_interface),
        ("Video Freshness", video_freshness_interface),
    ]
    with gr.Blocks() as demo:
        gr.Markdown("# Flipkart Grid Robotics Track - Fruits Interface")
        with gr.Tabs():
            for tab_title, builder in tab_builders:
                with gr.Tab(tab_title):
                    builder()
    return demo


Fruit = create_fruit_interface()
| |
|
| | |
| | |
| |
|
def create_tabbed_interface():
    """Combine the brand-recognition, OCR and fruit-freshness apps into one tabbed UI."""
    return gr.TabbedInterface(
        [Brand_recog, ocr_interface, Fruit],
        ["Brand Recognition", "OCR", "Fruit Freshness"],
    )


tabbed_interface = create_tabbed_interface()

"""# 7. Launch the Gradio Interface
### Finally, launch the Gradio interface to make it interactive.
"""

tabbed_interface.launch(debug=False)
| |
|