Tngarg committed on
Commit
3182e2f
·
verified ·
1 Parent(s): d7fae2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -93
app.py CHANGED
@@ -1,93 +1,95 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- import gradio as gr
5
- from PIL import Image
6
- from transformers import CLIPProcessor, CLIPModel
7
-
8
- # Load model parameters
9
- params = torch.load("clip_classification_params.pth", weights_only=False)
10
-
11
- correct_centroid = torch.tensor(params["correct_centroid"])
12
- incorrect_centroid = torch.tensor(params["incorrect_centroid"])
13
- threshold = params["threshold"]
14
-
15
- # Load CLIP model
16
- MODEL_NAME = "openai/clip-vit-base-patch32"
17
- clip_model = CLIPModel.from_pretrained(MODEL_NAME)
18
- clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
19
- clip_model.eval()
20
-
21
- # Paths to example images
22
- correct_examples = [
23
- "dataset/correct/correct (1).webp",
24
- "dataset/correct/correct (2).webp",
25
- "dataset/correct/correct (3).webp"
26
- ]
27
-
28
- incorrect_examples = [
29
- "dataset/incorrect/incorrect (1).webp",
30
- "dataset/incorrect/incorrect (2).webp",
31
- "dataset/incorrect/incorrect (3).webp"
32
- ]
33
-
34
- # Function to classify an image
35
- def classify_installation(image):
36
- """Classify if the bed installation is correct or incorrect."""
37
- inputs = clip_processor(images=image, return_tensors="pt", padding=True)
38
- with torch.no_grad():
39
- embedding = clip_model.get_image_features(**inputs)
40
- embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
41
-
42
- # Compute similarity to correct centroid
43
- similarity = torch.matmul(embedding, correct_centroid)
44
-
45
- # Compare similarity with threshold
46
- if similarity.item() >= threshold:
47
- return f"✅ Correct Installation (similarity = {similarity.item():.2f})"
48
- else:
49
- return f" Incorrect Installation (similarity = {similarity.item():.2f})"
50
-
51
- # Function to load image from file path
52
- def load_image(image_path):
53
- return Image.open(image_path).convert("RGB")
54
-
55
- # Function to process selected example image
56
- def process_example(image_path):
57
- image = load_image(image_path)
58
- return classify_installation(image), image
59
-
60
- # Gradio UI
61
- with gr.Blocks() as demo:
62
- gr.Markdown("# 🛏️ Bed Installation Classifier")
63
- gr.Markdown("Upload an image or select one from the examples below to check if the bed installation is correct.")
64
-
65
- with gr.Row():
66
- uploaded_image = gr.Image(type="pil", label="Upload an image for testing")
67
- output_text = gr.Textbox(label="Result")
68
-
69
- gr.Markdown("### Check Installations (Click Button to Classify)")
70
- with gr.Row():
71
- correct_buttons = []
72
- for i, img_path in enumerate(correct_examples):
73
- with gr.Column():
74
- gr.Image(value=load_image(img_path), interactive=False, width=150, height=150)
75
- btn = gr.Button(value=f"Check Accuracy")
76
- correct_buttons.append((btn, img_path))
77
-
78
- with gr.Row():
79
- incorrect_buttons = []
80
- for i, img_path in enumerate(incorrect_examples):
81
- with gr.Column():
82
- gr.Image(value=load_image(img_path), interactive=False, width=150, height=150)
83
- btn = gr.Button(value=f"Check Accuracy")
84
- incorrect_buttons.append((btn, img_path))
85
-
86
- # Connect buttons to classification function
87
- for btn, img_path in correct_buttons + incorrect_buttons:
88
- btn.click(fn=process_example, inputs=[gr.State(img_path)], outputs=[output_text, uploaded_image])
89
-
90
- # Process uploaded image
91
- uploaded_image.change(fn=classify_installation, inputs=[uploaded_image], outputs=output_text)
92
-
93
- demo.launch()
 
 
 
import os
import torch
import numpy as np
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load pre-computed classification parameters (class centroids + threshold).
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable
# only because this .pth ships with the app; never point it at untrusted files.
params = torch.load("clip_classification_params.pth", weights_only=False)

# Mean CLIP embedding of each class. Only the "correct" centroid is used for
# scoring in classify_installation, together with the similarity threshold.
correct_centroid = torch.tensor(params["correct_centroid"])
incorrect_centroid = torch.tensor(params["incorrect_centroid"])
threshold = params["threshold"]

# Load the CLIP backbone used to embed incoming images.
MODEL_NAME = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(MODEL_NAME)
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
clip_model.eval()  # inference only — disable dropout / training-mode layers

# Paths to the bundled example images shown in the UI galleries.
correct_examples = [
    "dataset/correct/correct (1).webp",
    "dataset/correct/correct (2).webp",
    "dataset/correct/correct (3).webp",
    "dataset/correct/correct (4).webp"
]

incorrect_examples = [
    "dataset/incorrect/incorrect (1).webp",
    "dataset/incorrect/incorrect (2).webp",
    "dataset/incorrect/incorrect (3).webp",
    "dataset/incorrect/incorrect (4).webp"
]
35
+
36
# Function to classify an image
def classify_installation(image):
    """Classify whether the bed installation in *image* is correct.

    Args:
        image: A PIL image (anything CLIPProcessor accepts). May be None
            when the Gradio image component is cleared.

    Returns:
        A human-readable verdict string including the similarity score.
    """
    # Robustness: the Gradio callback can deliver None when no image is set.
    if image is None:
        return "Please provide an image."

    # Embed the image with CLIP and L2-normalize so the dot product below
    # is a cosine similarity.
    inputs = clip_processor(images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = clip_model.get_image_features(**inputs)
        embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)

    # Cosine similarity to the "correct installation" centroid.
    # Hoisted .item() — the score was previously extracted three times.
    score = torch.matmul(embedding, correct_centroid).item()

    # Fix: the success message had lost its ✅ marker in the last edit while
    # the failure branch gained ❌ — restore it so both verdicts carry a
    # clear visual indicator.
    if score >= threshold:
        return f"✅ Correct Installation (similarity = {score:.2f})"
    else:
        return f"❌ Incorrect Installation (similarity = {score:.2f})"
52
+
53
# Function to load image from file path
def load_image(image_path):
    """Open *image_path* and return it as an RGB PIL image.

    Fix: use a context manager so the underlying file handle is closed
    promptly instead of leaking until garbage collection.
    """
    with Image.open(image_path) as img:
        # convert() forces the pixel data to load and returns a new image,
        # so returning it after the file closes is safe.
        return img.convert("RGB")
56
+
57
# Function to process selected example image
def process_example(image_path):
    """Classify the example at *image_path* and return (verdict, image)."""
    example_img = load_image(image_path)
    verdict = classify_installation(example_img)
    return verdict, example_img
61
+
62
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🛏️ Bed Installation Classifier")
    gr.Markdown("Upload an image or select one from the examples below to check if the bed installation is correct.")

    with gr.Row():
        uploaded_image = gr.Image(type="pil", label="Upload an image for testing")
        output_text = gr.Textbox(label="Result")

    gr.Markdown("### Check Installations (Click Button to Classify)")

    def _example_row(paths):
        """Render one row of example thumbnails; return [(button, path), ...]."""
        buttons = []
        with gr.Row():
            for img_path in paths:
                with gr.Column():
                    gr.Image(value=load_image(img_path), interactive=False, width=150, height=150)
                    # Plain string — the original f-string had no placeholders.
                    btn = gr.Button(value="Check Accuracy")
                    buttons.append((btn, img_path))
        return buttons

    # The two galleries were duplicated loops; factored into one helper.
    correct_buttons = _example_row(correct_examples)
    incorrect_buttons = _example_row(incorrect_examples)

    # Connect buttons to the classification function. gr.State(img_path)
    # pins each path at definition time (avoids the late-binding closure
    # pitfall a bare lambda would have).
    for btn, img_path in correct_buttons + incorrect_buttons:
        btn.click(fn=process_example, inputs=[gr.State(img_path)], outputs=[output_text, uploaded_image])

    # Fix: use .upload instead of .change — .change also fires when an
    # example button writes its image into `uploaded_image`, re-running the
    # classifier a second time (and firing with None when the image clears).
    uploaded_image.upload(fn=classify_installation, inputs=[uploaded_image], outputs=output_text)

demo.launch()