Update app.py
Browse files
app.py
CHANGED
|
@@ -1,22 +1,15 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
-
"""
|
| 3 |
|
| 4 |
Automatically generated by Colab.
|
| 5 |
|
| 6 |
Original file is located at
|
| 7 |
-
https://colab.research.google.com/
|
| 8 |
|
| 9 |
# 1. Install Gradio and Required Libraries
|
| 10 |
### Start by installing Gradio if it's not already installed.
|
| 11 |
"""
|
| 12 |
|
| 13 |
-
# ! pip install gradio
|
| 14 |
-
# ! pip install cv
|
| 15 |
-
# ! pip install ultralytics
|
| 16 |
-
# ! pip install supervision
|
| 17 |
-
# !pip install google-generativeai
|
| 18 |
-
# !pip install paddleocr
|
| 19 |
-
# !pip install paddlepaddle
|
| 20 |
|
| 21 |
"""# 2. Import Libraries
|
| 22 |
### Getting all the necessary Libraries
|
|
@@ -29,48 +22,35 @@ from PIL import Image
|
|
| 29 |
import cv2
|
| 30 |
import time
|
| 31 |
from ultralytics import YOLO
|
| 32 |
-
import supervision as sv
|
| 33 |
import pandas as pd
|
|
|
|
| 34 |
from collections import defaultdict, deque
|
| 35 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
import google.generativeai as genai
|
| 37 |
from datetime import datetime
|
| 38 |
from paddleocr import PaddleOCR
|
| 39 |
import os
|
| 40 |
|
| 41 |
-
"""# Path Variables
|
| 42 |
-
|
| 43 |
-
### Path used in OCR
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
OCR_M3="best.pt"
|
| 47 |
-
GOOGLE_API_KEY = os.getenv("GEMINI_API")
|
| 48 |
-
GEMINI_MODEL = 'models/gemini-1.5-flash'
|
| 49 |
-
|
| 50 |
-
"""### Path used in Brand Recognition model"""
|
| 51 |
-
|
| 52 |
-
Brand_Recognition_Model ='kitkat_s.pt'
|
| 53 |
-
annotatedOpFile= 'annotated_output.mp4'
|
| 54 |
-
|
| 55 |
"""# 3. Import Drive
|
| 56 |
|
| 57 |
"""
|
| 58 |
|
| 59 |
# from google.colab import drive
|
| 60 |
-
|
| 61 |
# drive.mount('/content/drive')
|
| 62 |
|
| 63 |
"""# 4. Brand Recognition Backend
|
| 64 |
|
| 65 |
-
###
|
| 66 |
"""
|
| 67 |
|
| 68 |
-
model_path = Brand_Recognition_Model
|
| 69 |
-
model = YOLO(model_path)
|
| 70 |
-
|
| 71 |
-
"""### Image uploading for Grocery detection"""
|
| 72 |
-
|
| 73 |
def detect_grocery_items(image):
|
|
|
|
| 74 |
image = np.array(image)[:, :, ::-1]
|
| 75 |
results = model(image)
|
| 76 |
annotated_image = results[0].plot()
|
|
@@ -109,7 +89,6 @@ def detect_grocery_items(image):
|
|
| 109 |
"""### Detect Grovcery brand from video"""
|
| 110 |
|
| 111 |
def iou(box1, box2):
|
| 112 |
-
# Calculate intersection over union
|
| 113 |
x1 = max(box1[0], box2[0])
|
| 114 |
y1 = max(box1[1], box2[1])
|
| 115 |
x2 = min(box1[2], box2[2])
|
|
@@ -128,22 +107,19 @@ def smooth_box(box_history):
|
|
| 128 |
return np.mean(box_history, axis=0)
|
| 129 |
|
| 130 |
def process_video(input_path, output_path):
|
|
|
|
| 131 |
cap = cv2.VideoCapture(input_path)
|
| 132 |
|
| 133 |
-
# Get video properties
|
| 134 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 135 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 136 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 137 |
|
| 138 |
-
# Initialize video writer
|
| 139 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 140 |
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 141 |
|
| 142 |
-
# Initialize variables for tracking
|
| 143 |
detected_items = {}
|
| 144 |
frame_count = 0
|
| 145 |
|
| 146 |
-
# For result confirmation
|
| 147 |
detections_history = defaultdict(lambda: defaultdict(int))
|
| 148 |
|
| 149 |
while cap.isOpened():
|
|
@@ -153,7 +129,6 @@ def process_video(input_path, output_path):
|
|
| 153 |
|
| 154 |
frame_count += 1
|
| 155 |
|
| 156 |
-
# Run YOLO detection every 5th frame
|
| 157 |
if frame_count % 5 == 0:
|
| 158 |
results = model(frame)
|
| 159 |
|
|
@@ -169,7 +144,6 @@ def process_video(input_path, output_path):
|
|
| 169 |
|
| 170 |
current_frame_detections.append((brand, [x1, y1, x2, y2], conf))
|
| 171 |
|
| 172 |
-
# Match current detections with existing items
|
| 173 |
for brand, box, conf in current_frame_detections:
|
| 174 |
matched = False
|
| 175 |
for item_id, item_info in detected_items.items():
|
|
@@ -203,7 +177,6 @@ def process_video(input_path, output_path):
|
|
| 203 |
del detected_items[item_id]
|
| 204 |
continue
|
| 205 |
|
| 206 |
-
# Interpolate box position
|
| 207 |
if item_info['smoothed_box'] is not None:
|
| 208 |
alpha = 0.3
|
| 209 |
current_box = item_info['smoothed_box']
|
|
@@ -224,7 +197,6 @@ def process_video(input_path, output_path):
|
|
| 224 |
cap.release()
|
| 225 |
out.release()
|
| 226 |
|
| 227 |
-
# Calculate final counts and confirm results
|
| 228 |
total_frames = frame_count
|
| 229 |
confirmed_items = {}
|
| 230 |
for brand, frame_counts in detections_history.items():
|
|
@@ -236,7 +208,7 @@ def process_video(input_path, output_path):
|
|
| 236 |
return confirmed_items
|
| 237 |
|
| 238 |
def annotate_video(input_video):
|
| 239 |
-
output_path =
|
| 240 |
confirmed_items = process_video(input_video, output_path)
|
| 241 |
|
| 242 |
item_list = [(brand, quantity) for brand, quantity in confirmed_items.items()]
|
|
@@ -281,6 +253,7 @@ def draw_bounding_boxes(image_path):
|
|
| 281 |
return all_text_data
|
| 282 |
|
| 283 |
# Set your API key securely (store it in Colab’s userdata)
|
|
|
|
| 284 |
genai.configure(api_key=GOOGLE_API_KEY)
|
| 285 |
|
| 286 |
def gemini_context_correction(text):
|
|
@@ -301,14 +274,9 @@ def gemini_context_correction(text):
|
|
| 301 |
|
| 302 |
return response.text
|
| 303 |
|
| 304 |
-
# Test Gemini with example text (replace with actual OCR output)
|
| 305 |
-
sample_text = "EXP 12/2024 MFD 08/2023 Best Before 06/2025 MRP Rs. 250/-"
|
| 306 |
-
refined_output = gemini_context_correction(sample_text)
|
| 307 |
-
print("[DEBUG] Gemini Refined Output:\n", refined_output)
|
| 308 |
-
|
| 309 |
def validate_dates_with_gemini(mfg_date, exp_date):
|
| 310 |
"""Use Gemini API to validate and correct the manufacturing and expiration dates."""
|
| 311 |
-
model = genai.GenerativeModel(
|
| 312 |
response = model.generate_content = (
|
| 313 |
f"Input Manufacturing Date: {mfg_date}, Expiration Date: {exp_date}. "
|
| 314 |
f"If either date is '-1', leave it as is. "
|
|
@@ -332,7 +300,7 @@ def extract_and_validate_with_gemini(refined_text):
|
|
| 332 |
"""
|
| 333 |
Use Gemini API to extract, validate, and correct manufacturing and expiration dates.
|
| 334 |
"""
|
| 335 |
-
model = genai.GenerativeModel(
|
| 336 |
|
| 337 |
# Correctly call the generate_content method
|
| 338 |
response = model.generate_content(
|
|
@@ -381,7 +349,7 @@ def extract_and_validate_with_gemini(refined_text):
|
|
| 381 |
"""
|
| 382 |
Use Gemini API to extract, validate, correct, and swap dates in 'yyyy/mm/dd' format if necessary.
|
| 383 |
"""
|
| 384 |
-
model = genai.GenerativeModel(
|
| 385 |
|
| 386 |
# Generate content using Gemini with the refined prompt
|
| 387 |
response = model.generate_content(
|
|
@@ -455,9 +423,6 @@ def extract_date(refined_text, date_type):
|
|
| 455 |
return '-1' # Return -1 if the date is not found
|
| 456 |
return '-1' # Return -1 if the date type is not in the text
|
| 457 |
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
"""### **Model 3**
|
| 462 |
Using Yolov8 x-large model trained till about 75 epochs
|
| 463 |
and
|
|
@@ -466,9 +431,7 @@ Gradio as user interface
|
|
| 466 |
|
| 467 |
"""
|
| 468 |
|
| 469 |
-
|
| 470 |
-
model = YOLO(model_path)
|
| 471 |
-
|
| 472 |
"""## Driver code to be run after selecting from Model 2 or 3.
|
| 473 |
(Note: not needed for model 1)
|
| 474 |
"""
|
|
@@ -601,6 +564,310 @@ def handle_processing(validated_output):
|
|
| 601 |
print("[DEBUG] Hiding the 'Further Processing' button.") # Debug print
|
| 602 |
return gr.update(visible=False) # Hide button if dates are valid
|
| 603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
"""# 5. Frontend Of Brand Recognition
|
| 605 |
|
| 606 |
## Layout for Image interface
|
|
@@ -712,14 +979,66 @@ def create_ocr_interface():
|
|
| 712 |
ocr_interface = create_ocr_interface()
|
| 713 |
# ocr_interface.launch(share=True, debug=True)
|
| 714 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
"""# 6. Create a Tabbed Interface for Both Image and Video
|
| 716 |
### Here, we combine the image and video interfaces into a tabbed structure so users can switch between them easily.
|
| 717 |
"""
|
| 718 |
|
| 719 |
def create_tabbed_interface():
|
| 720 |
return gr.TabbedInterface(
|
| 721 |
-
[Brand_recog, ocr_interface ],
|
| 722 |
-
["Brand Recongnition", "OCR"]
|
| 723 |
)
|
| 724 |
|
| 725 |
tabbed_interface = create_tabbed_interface()
|
|
@@ -728,4 +1047,4 @@ tabbed_interface = create_tabbed_interface()
|
|
| 728 |
### Finally, launch the Gradio interface to make it interactable.
|
| 729 |
"""
|
| 730 |
|
| 731 |
-
tabbed_interface.launch()
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
+
"""Complete_3_model_code.ipynb
|
| 3 |
|
| 4 |
Automatically generated by Colab.
|
| 5 |
|
| 6 |
Original file is located at
|
| 7 |
+
https://colab.research.google.com/drive/1Ivlv1jHXwoldi9Mb-quvDBlCNeIIAhc9
|
| 8 |
|
| 9 |
# 1. Install Gradio and Required Libraries
|
| 10 |
### Start by installing Gradio if it's not already installed.
|
| 11 |
"""
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
"""# 2. Import Libraries
|
| 15 |
### Getting all the necessary Libraries
|
|
|
|
| 22 |
import cv2
|
| 23 |
import time
|
| 24 |
from ultralytics import YOLO
|
|
|
|
| 25 |
import pandas as pd
|
| 26 |
+
from IPython.display import clear_output
|
| 27 |
from collections import defaultdict, deque
|
| 28 |
import matplotlib.pyplot as plt
|
| 29 |
+
import torch
|
| 30 |
+
from PIL import Image
|
| 31 |
+
from torchvision import transforms, models, datasets, transforms
|
| 32 |
+
from torch.utils.data import DataLoader
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import matplotlib.pyplot as plt
|
| 35 |
import google.generativeai as genai
|
| 36 |
from datetime import datetime
|
| 37 |
from paddleocr import PaddleOCR
|
| 38 |
import os
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
"""# 3. Import Drive
|
| 41 |
|
| 42 |
"""
|
| 43 |
|
| 44 |
# from google.colab import drive
|
|
|
|
| 45 |
# drive.mount('/content/drive')
|
| 46 |
|
| 47 |
"""# 4. Brand Recognition Backend
|
| 48 |
|
| 49 |
+
### Image uploading for Grocery detection
|
| 50 |
"""
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def detect_grocery_items(image):
|
| 53 |
+
model = YOLO('kitkat_s.pt')
|
| 54 |
image = np.array(image)[:, :, ::-1]
|
| 55 |
results = model(image)
|
| 56 |
annotated_image = results[0].plot()
|
|
|
|
| 89 |
"""### Detect Grovcery brand from video"""
|
| 90 |
|
| 91 |
def iou(box1, box2):
|
|
|
|
| 92 |
x1 = max(box1[0], box2[0])
|
| 93 |
y1 = max(box1[1], box2[1])
|
| 94 |
x2 = min(box1[2], box2[2])
|
|
|
|
| 107 |
return np.mean(box_history, axis=0)
|
| 108 |
|
| 109 |
def process_video(input_path, output_path):
|
| 110 |
+
model = YOLO('kitkat_n.pt')
|
| 111 |
cap = cv2.VideoCapture(input_path)
|
| 112 |
|
|
|
|
| 113 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 114 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 115 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 116 |
|
|
|
|
| 117 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 118 |
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 119 |
|
|
|
|
| 120 |
detected_items = {}
|
| 121 |
frame_count = 0
|
| 122 |
|
|
|
|
| 123 |
detections_history = defaultdict(lambda: defaultdict(int))
|
| 124 |
|
| 125 |
while cap.isOpened():
|
|
|
|
| 129 |
|
| 130 |
frame_count += 1
|
| 131 |
|
|
|
|
| 132 |
if frame_count % 5 == 0:
|
| 133 |
results = model(frame)
|
| 134 |
|
|
|
|
| 144 |
|
| 145 |
current_frame_detections.append((brand, [x1, y1, x2, y2], conf))
|
| 146 |
|
|
|
|
| 147 |
for brand, box, conf in current_frame_detections:
|
| 148 |
matched = False
|
| 149 |
for item_id, item_info in detected_items.items():
|
|
|
|
| 177 |
del detected_items[item_id]
|
| 178 |
continue
|
| 179 |
|
|
|
|
| 180 |
if item_info['smoothed_box'] is not None:
|
| 181 |
alpha = 0.3
|
| 182 |
current_box = item_info['smoothed_box']
|
|
|
|
| 197 |
cap.release()
|
| 198 |
out.release()
|
| 199 |
|
|
|
|
| 200 |
total_frames = frame_count
|
| 201 |
confirmed_items = {}
|
| 202 |
for brand, frame_counts in detections_history.items():
|
|
|
|
| 208 |
return confirmed_items
|
| 209 |
|
| 210 |
def annotate_video(input_video):
|
| 211 |
+
output_path = 'annotated_output.mp4'
|
| 212 |
confirmed_items = process_video(input_video, output_path)
|
| 213 |
|
| 214 |
item_list = [(brand, quantity) for brand, quantity in confirmed_items.items()]
|
|
|
|
| 253 |
return all_text_data
|
| 254 |
|
| 255 |
# Set your API key securely (store it in Colab’s userdata)
|
| 256 |
+
GOOGLE_API_KEY= os.getenv("GEMINI_API")
|
| 257 |
genai.configure(api_key=GOOGLE_API_KEY)
|
| 258 |
|
| 259 |
def gemini_context_correction(text):
|
|
|
|
| 274 |
|
| 275 |
return response.text
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
def validate_dates_with_gemini(mfg_date, exp_date):
|
| 278 |
"""Use Gemini API to validate and correct the manufacturing and expiration dates."""
|
| 279 |
+
model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 280 |
response = model.generate_content = (
|
| 281 |
f"Input Manufacturing Date: {mfg_date}, Expiration Date: {exp_date}. "
|
| 282 |
f"If either date is '-1', leave it as is. "
|
|
|
|
| 300 |
"""
|
| 301 |
Use Gemini API to extract, validate, and correct manufacturing and expiration dates.
|
| 302 |
"""
|
| 303 |
+
model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 304 |
|
| 305 |
# Correctly call the generate_content method
|
| 306 |
response = model.generate_content(
|
|
|
|
| 349 |
"""
|
| 350 |
Use Gemini API to extract, validate, correct, and swap dates in 'yyyy/mm/dd' format if necessary.
|
| 351 |
"""
|
| 352 |
+
model = genai.GenerativeModel('models/gemini-1.5-flash')
|
| 353 |
|
| 354 |
# Generate content using Gemini with the refined prompt
|
| 355 |
response = model.generate_content(
|
|
|
|
| 423 |
return '-1' # Return -1 if the date is not found
|
| 424 |
return '-1' # Return -1 if the date type is not in the text
|
| 425 |
|
|
|
|
|
|
|
|
|
|
| 426 |
"""### **Model 3**
|
| 427 |
Using Yolov8 x-large model trained till about 75 epochs
|
| 428 |
and
|
|
|
|
| 431 |
|
| 432 |
"""
|
| 433 |
|
| 434 |
+
model = YOLO('best.pt')
|
|
|
|
|
|
|
| 435 |
"""## Driver code to be run after selecting from Model 2 or 3.
|
| 436 |
(Note: not needed for model 1)
|
| 437 |
"""
|
|
|
|
| 564 |
print("[DEBUG] Hiding the 'Further Processing' button.") # Debug print
|
| 565 |
return gr.update(visible=False) # Hide button if dates are valid
|
| 566 |
|
| 567 |
+
"""# Freshness Backend"""
|
| 568 |
+
|
| 569 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class EfficientNet_FeatureExtractor(nn.Module):
|
| 573 |
+
|
| 574 |
+
def __init__(self):
|
| 575 |
+
super(EfficientNet_FeatureExtractor, self).__init__()
|
| 576 |
+
self.efficientnet = models.efficientnet_b0(pretrained=True)
|
| 577 |
+
self.efficientnet = nn.Sequential(*list(self.efficientnet.children())[:-1])
|
| 578 |
+
|
| 579 |
+
def forward(self, x):
|
| 580 |
+
x = self.efficientnet(x)
|
| 581 |
+
x = x.view(x.size(0), -1)
|
| 582 |
+
|
| 583 |
+
return x
|
| 584 |
+
|
| 585 |
+
# Calculating the mean and variance of the images whose features will be extracted
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
transform = transforms.Compose([
|
| 589 |
+
transforms.Resize(256),
|
| 590 |
+
transforms.CenterCrop(224),
|
| 591 |
+
transforms.ToTensor(),
|
| 592 |
+
])
|
| 593 |
+
|
| 594 |
+
dataset = datasets.ImageFolder(root='Datasets/Bananas', transform=transform)
|
| 595 |
+
|
| 596 |
+
# Create a DataLoader
|
| 597 |
+
loader = DataLoader(dataset, batch_size=32, shuffle=False)
|
| 598 |
+
|
| 599 |
+
# Initialize variables to calculate the mean and std
|
| 600 |
+
mean = 0.0
|
| 601 |
+
std = 0.0
|
| 602 |
+
total_images = 0
|
| 603 |
+
|
| 604 |
+
# Iterate over the dataset to compute mean and std
|
| 605 |
+
for images, _ in loader:
|
| 606 |
+
batch_samples = images.size(0)
|
| 607 |
+
images = images.view(batch_samples, images.size(1), -1) # Flatten each image (C, H*W)
|
| 608 |
+
|
| 609 |
+
# Calculate mean and std for this batch and add to the running total
|
| 610 |
+
mean += images.mean(2).sum(0)
|
| 611 |
+
std += images.std(2).sum(0)
|
| 612 |
+
total_images += batch_samples
|
| 613 |
+
|
| 614 |
+
# Final mean and std across all images in the dataset
|
| 615 |
+
mean /= total_images
|
| 616 |
+
std /= total_images
|
| 617 |
+
|
| 618 |
+
print(f"Mean: {mean}")
|
| 619 |
+
print(f"Std: {std}")
|
| 620 |
+
|
| 621 |
+
# Transforming the images into the format so that they can be passes through the EfficientNet model
|
| 622 |
+
# Define the transform for your dataset, including normalization with custom mean and std
|
| 623 |
+
transform = transforms.Compose([
|
| 624 |
+
transforms.Resize(256),
|
| 625 |
+
transforms.CenterCrop(224),
|
| 626 |
+
transforms.ToTensor(),
|
| 627 |
+
transforms.Normalize(mean=mean, std=std)
|
| 628 |
+
])
|
| 629 |
+
|
| 630 |
+
test_dataset = datasets.ImageFolder(root='/Dataset/Bananas', transform=transform)
|
| 631 |
+
|
| 632 |
+
# Extracting features from Efficientnet model
|
| 633 |
+
def extract_features(test_dataset):
|
| 634 |
+
|
| 635 |
+
# Initialize the feature extractor model
|
| 636 |
+
model = EfficientNet_FeatureExtractor().to(device)
|
| 637 |
+
model.eval() # Set to evaluation mode
|
| 638 |
+
|
| 639 |
+
# Create a DataLoader for the test dataset
|
| 640 |
+
test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
# Store the extracted features
|
| 644 |
+
all_features = []
|
| 645 |
+
|
| 646 |
+
# Loop over the test dataset and extract features
|
| 647 |
+
with torch.no_grad(): # Disable gradient calculation for efficiency
|
| 648 |
+
for images, _ in test_loader:
|
| 649 |
+
# Send the images to the same device as the model
|
| 650 |
+
images = images.to(device)
|
| 651 |
+
|
| 652 |
+
# Pass the images through the feature extractor
|
| 653 |
+
features = model(images)
|
| 654 |
+
|
| 655 |
+
# Move features to CPU and convert to NumPy (optional)
|
| 656 |
+
features = features.cpu().numpy()
|
| 657 |
+
|
| 658 |
+
# Append the features for further use
|
| 659 |
+
all_features.append(features)
|
| 660 |
+
return all_features
|
| 661 |
+
|
| 662 |
+
all_features = extract_features(test_dataset)
|
| 663 |
+
|
| 664 |
+
# Print the shape of each batch stored in the list
|
| 665 |
+
for i, features in enumerate(all_features):
|
| 666 |
+
print(f"Shape of batch {i}: {features.shape}")
|
| 667 |
+
|
| 668 |
+
# Calculating the mean and varinance of the entire distribution
|
| 669 |
+
|
| 670 |
+
# Stack all the feature vectors into a single tensor
|
| 671 |
+
all_features_tensor = torch.cat([torch.tensor(batch) for batch in all_features], dim=0)
|
| 672 |
+
|
| 673 |
+
# Calculate the mean and variance along the feature dimension
|
| 674 |
+
feature_mean = all_features_tensor.mean(dim=0)
|
| 675 |
+
feature_mean = feature_mean.to(device)
|
| 676 |
+
feature_variance = all_features_tensor.var(dim=0)
|
| 677 |
+
|
| 678 |
+
print(f"Feature Mean Shape: {feature_mean.shape}")
|
| 679 |
+
|
| 680 |
+
all_features_tensor = torch.cat([torch.tensor(f) for f in all_features], dim=0)
|
| 681 |
+
all_features_tensor = all_features_tensor.to(device)
|
| 682 |
+
feature_mean_temp = all_features_tensor.mean(dim=0)
|
| 683 |
+
centered_features = all_features_tensor - feature_mean_temp
|
| 684 |
+
|
| 685 |
+
# Calculate the covariance matrix
|
| 686 |
+
# Covariance matrix: (num_features, num_features)
|
| 687 |
+
covariance_matrix = torch.cov(centered_features.T)
|
| 688 |
+
covariance_matrix = covariance_matrix.to(device)
|
| 689 |
+
|
| 690 |
+
print(f"All Feature Tensor Shape: {all_features_tensor.shape}")
|
| 691 |
+
print(f"Covariance Matrix Shape: {covariance_matrix.shape}")
|
| 692 |
+
|
| 693 |
+
# Defining the function to calculate the Mahalanobis distance
|
| 694 |
+
|
| 695 |
+
import torch
|
| 696 |
+
|
| 697 |
+
def mahalanobis(x=None, feature_mean=None, feature_cov=None):
|
| 698 |
+
"""Compute the Mahalanobis Distance between each row of x and the data
|
| 699 |
+
x : tensor of shape [batch_size, num_features], feature vectors of test data
|
| 700 |
+
feature_mean : tensor of shape [num_features], mean of the training feature vectors
|
| 701 |
+
feature_cov : tensor of shape [num_features, num_features], covariance matrix of the training feature vectors
|
| 702 |
+
"""
|
| 703 |
+
|
| 704 |
+
# Subtract the mean from x
|
| 705 |
+
x_minus_mu = x - feature_mean
|
| 706 |
+
|
| 707 |
+
# Invert the covariance matrix
|
| 708 |
+
inv_covmat = torch.inverse(feature_cov)
|
| 709 |
+
|
| 710 |
+
# Mahalanobis distance computation: (x - mu)^T * inv_cov * (x - mu)
|
| 711 |
+
left_term = torch.matmul(x_minus_mu, inv_covmat)
|
| 712 |
+
mahal = torch.matmul(left_term, x_minus_mu.T)
|
| 713 |
+
return mahal.diag()
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
transform = transforms.Compose([
|
| 717 |
+
transforms.Resize(256),
|
| 718 |
+
transforms.CenterCrop(224),
|
| 719 |
+
transforms.ToTensor(),
|
| 720 |
+
transforms.Normalize(mean=mean, std=std)
|
| 721 |
+
])
|
| 722 |
+
|
| 723 |
+
def classify_banana_by_distance(distance):
|
| 724 |
+
"""
|
| 725 |
+
Classifies the banana's freshness based on the Mahalanobis distance.
|
| 726 |
+
|
| 727 |
+
Args:
|
| 728 |
+
distance (float): Mahalanobis distance of the banana.
|
| 729 |
+
|
| 730 |
+
Returns:
|
| 731 |
+
dict: A dictionary containing the classification and relevant details.
|
| 732 |
+
"""
|
| 733 |
+
|
| 734 |
+
# Define thresholds for classification based on the provided distances
|
| 735 |
+
if distance >= 9:
|
| 736 |
+
# Case 1: Completely Fresh Banana
|
| 737 |
+
return {
|
| 738 |
+
"Classification": "Completely Fresh",
|
| 739 |
+
"Freshness Index": 10,
|
| 740 |
+
"Color": "Mostly yellow, little to no brown spots",
|
| 741 |
+
"Dark Spots": "0-10%",
|
| 742 |
+
"Shelf Life": "5-7 days",
|
| 743 |
+
"Ripeness Stage": "Just ripe",
|
| 744 |
+
"Texture": "Firm and smooth"
|
| 745 |
+
}
|
| 746 |
+
elif -90 <= distance < 0:
|
| 747 |
+
# Case 2: Banana with 40% Dark Brown Spots
|
| 748 |
+
return {
|
| 749 |
+
"Classification": "Moderately Ripe",
|
| 750 |
+
"Freshness Index": 6,
|
| 751 |
+
"Color": "60% yellow, 40% dark spots",
|
| 752 |
+
"Dark Spots": "40% dark spots",
|
| 753 |
+
"Shelf Life": "2-3 days",
|
| 754 |
+
"Ripeness Stage": "Moderately ripe",
|
| 755 |
+
"Texture": "Some softness, still edible"
|
| 756 |
+
}
|
| 757 |
+
else:
|
| 758 |
+
# Case 3: Almost Rotten Banana
|
| 759 |
+
return {
|
| 760 |
+
"Classification": "Almost Rotten",
|
| 761 |
+
"Freshness Index": 2,
|
| 762 |
+
"Color": "Mostly brown or black, very few yellow patches",
|
| 763 |
+
"Dark Spots": "80-100% dark spots",
|
| 764 |
+
"Shelf Life": "0-1 days",
|
| 765 |
+
"Ripeness Stage": "Overripe",
|
| 766 |
+
"Texture": "Very soft, mushy, may leak moisture"
|
| 767 |
+
}
|
| 768 |
+
|
| 769 |
+
return result
|
| 770 |
+
|
| 771 |
+
def classify_banana(image):
|
| 772 |
+
|
| 773 |
+
model = EfficientNet_FeatureExtractor().to(device)
|
| 774 |
+
model.eval() # Set to evaluation mode
|
| 775 |
+
|
| 776 |
+
# Load and transform the image
|
| 777 |
+
img = Image.fromarray(image)
|
| 778 |
+
img_transformed = transform(img).unsqueeze(0).to(device)
|
| 779 |
+
|
| 780 |
+
# Feature extraction
|
| 781 |
+
with torch.no_grad():
|
| 782 |
+
features = model(img_transformed)
|
| 783 |
+
|
| 784 |
+
# Calculate Mahalanobis distance
|
| 785 |
+
distance = mahalanobis(features, feature_mean, covariance_matrix)
|
| 786 |
+
distance = (distance) / 1e8
|
| 787 |
+
|
| 788 |
+
return classify_banana_by_distance(distance)
|
| 789 |
+
|
| 790 |
+
"""## Freshness Detect Using image"""
|
| 791 |
+
|
| 792 |
+
def detect_objects(image):
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
# Load the YOLO model
|
| 796 |
+
model = YOLO('Yash_Best.pt')
|
| 797 |
+
# Run inference on the image
|
| 798 |
+
result = model(image)
|
| 799 |
+
|
| 800 |
+
# Get the image from the result
|
| 801 |
+
img = result[0].orig_img # Original image
|
| 802 |
+
|
| 803 |
+
# If bounding boxes are detected, loop over them and draw them
|
| 804 |
+
if result[0].boxes is not None:
|
| 805 |
+
for i, box in enumerate(result[0].boxes.xyxy): # Bounding boxes (x1, y1, x2, y2)
|
| 806 |
+
x1, y1, x2, y2 = map(int, box[:4])
|
| 807 |
+
conf = result[0].boxes.conf[i].item() # Confidence score
|
| 808 |
+
cls = int(result[0].boxes.cls[i].item()) # Class ID
|
| 809 |
+
|
| 810 |
+
# Get the label name
|
| 811 |
+
label = f'{result[0].names[cls]} {conf:.2f}'
|
| 812 |
+
|
| 813 |
+
# Draw the bounding box
|
| 814 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2) # Green box
|
| 815 |
+
cv2.putText(img, label, (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
| 816 |
+
|
| 817 |
+
# Convert image to RGB for displaying in Gradio
|
| 818 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 819 |
+
|
| 820 |
+
return img_rgb
|
| 821 |
+
|
| 822 |
+
"""## Freshness Detect using Video"""
|
| 823 |
+
|
| 824 |
+
def detect_objects_video(video_file):
|
| 825 |
+
|
| 826 |
+
# Load the YOLO model
|
| 827 |
+
model = YOLO('Yash_Best.pt')
|
| 828 |
+
# Open the video file
|
| 829 |
+
cap = cv2.VideoCapture(video_file.name)
|
| 830 |
+
|
| 831 |
+
# Get video properties
|
| 832 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 833 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 834 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 835 |
+
|
| 836 |
+
# Output video writer to save the results
|
| 837 |
+
output_video_path = 'output_detected_video.mp4'
|
| 838 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 839 |
+
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
| 840 |
+
|
| 841 |
+
# Process each frame from the video
|
| 842 |
+
while cap.isOpened():
|
| 843 |
+
ret, frame = cap.read()
|
| 844 |
+
if not ret:
|
| 845 |
+
break # Exit if there are no more frames
|
| 846 |
+
|
| 847 |
+
# Run object detection on the frame
|
| 848 |
+
results = model(frame)
|
| 849 |
+
|
| 850 |
+
# Loop over detection results and draw bounding boxes with labels
|
| 851 |
+
if results[0].boxes is not None:
|
| 852 |
+
for i, box in enumerate(results[0].boxes.xyxy): # Bounding boxes (x1, y1, x2, y2)
|
| 853 |
+
x1, y1, x2, y2 = map(int, box[:4])
|
| 854 |
+
conf = results[0].boxes.conf[i].item() # Confidence score
|
| 855 |
+
cls = int(results[0].boxes.cls[i].item()) # Class ID
|
| 856 |
+
label = f'{results[0].names[cls]} {conf:.2f}'
|
| 857 |
+
|
| 858 |
+
# Draw bounding box and label
|
| 859 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 860 |
+
cv2.putText(frame, label, (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
| 861 |
+
|
| 862 |
+
# Write the processed frame to the output video
|
| 863 |
+
out.write(frame)
|
| 864 |
+
|
| 865 |
+
# Release resources
|
| 866 |
+
cap.release()
|
| 867 |
+
out.release()
|
| 868 |
+
|
| 869 |
+
return output_video_path
|
| 870 |
+
|
| 871 |
"""# 5. Frontend Of Brand Recognition
|
| 872 |
|
| 873 |
## Layout for Image interface
|
|
|
|
| 979 |
ocr_interface = create_ocr_interface()
|
| 980 |
# ocr_interface.launch(share=True, debug=True)
|
| 981 |
|
| 982 |
+
|
| 983 |
+
"""# Frontend for Fruit Freshness
|
| 984 |
+
|
| 985 |
+
## Layout for Freshness Index
|
| 986 |
+
"""
|
| 987 |
+
|
| 988 |
+
def create_banana_classifier_interface():
|
| 989 |
+
return gr.Interface(
|
| 990 |
+
fn=classify_banana, # Your classification function
|
| 991 |
+
inputs=gr.Image(type="numpy", label="Upload a Banana Image"), # Removed tool argument
|
| 992 |
+
outputs=gr.JSON(label="Classification Result"),
|
| 993 |
+
title="Banana Freshness Classifier",
|
| 994 |
+
description="Upload an image of a banana to classify its freshness.",
|
| 995 |
+
css="#component-0 { width: 300px; height: 300px; }" # Keep your CSS for fixed size
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
def image_freshness_interface():
|
| 999 |
+
return gr.Interface(
|
| 1000 |
+
fn=detect_objects, # Your detection function
|
| 1001 |
+
inputs=gr.Image(type="numpy", label="Upload an Image"), # Removed tool argument
|
| 1002 |
+
outputs=gr.Image(type="numpy", label="Detected Image"),
|
| 1003 |
+
live=True,
|
| 1004 |
+
title="Image Freshness Detection",
|
| 1005 |
+
description="Upload an image of fruit to detect freshness.",
|
| 1006 |
+
css="#component-0 { width: 300px; height: 300px; }" # Keep your CSS for fixed size
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
def video_freshness_interface():
|
| 1010 |
+
return gr.Interface(
|
| 1011 |
+
fn=process_video, # Your video processing function
|
| 1012 |
+
inputs=gr.Video(label="Upload a Video"),
|
| 1013 |
+
outputs=gr.Video(label="Processed Video"),
|
| 1014 |
+
title="Video Freshness Detection",
|
| 1015 |
+
description="Upload a video of fruit to detect freshness.",
|
| 1016 |
+
css="#component-0 { width: 300px; height: 300px; }" # Keep your CSS for fixed size
|
| 1017 |
+
)
|
| 1018 |
+
|
| 1019 |
+
def create_fruit_interface():
|
| 1020 |
+
with gr.Blocks() as demo:
|
| 1021 |
+
gr.Markdown("# Flipkart Grid Robotics Track - Fruits Interface")
|
| 1022 |
+
with gr.Tabs():
|
| 1023 |
+
with gr.Tab("Banana"):
|
| 1024 |
+
create_banana_classifier_interface() # Call the banana classifier interface
|
| 1025 |
+
with gr.Tab("Image Freshness"):
|
| 1026 |
+
image_freshness_interface() # Call the image freshness interface
|
| 1027 |
+
with gr.Tab("Video Freshness"):
|
| 1028 |
+
video_freshness_interface() # Call the video freshness interface
|
| 1029 |
+
return demo
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
Fruit = create_fruit_interface()
|
| 1033 |
+
|
| 1034 |
"""# 6. Create a Tabbed Interface for Both Image and Video
|
| 1035 |
### Here, we combine the image and video interfaces into a tabbed structure so users can switch between them easily.
|
| 1036 |
"""
|
| 1037 |
|
| 1038 |
def create_tabbed_interface():
|
| 1039 |
return gr.TabbedInterface(
|
| 1040 |
+
[Brand_recog, ocr_interface,Fruit ],
|
| 1041 |
+
["Brand Recongnition", "OCR" , "Fruit Freshness"]
|
| 1042 |
)
|
| 1043 |
|
| 1044 |
tabbed_interface = create_tabbed_interface()
|
|
|
|
| 1047 |
### Finally, launch the Gradio interface to make it interactable.
|
| 1048 |
"""
|
| 1049 |
|
| 1050 |
+
tabbed_interface.launch()
|