Ut14 committed on
Commit
9419ab1
·
verified ·
1 Parent(s): 622c65d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +46 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import joblib
7
+ import gradio as gr
8
+ from transformers import CLIPProcessor, CLIPModel
9
+
10
+ # --- Load CLIP Model and Processor ---
11
# The CLIP model and its processor are created from the same checkpoint id.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# --- Load Trained SVM Model ---
# The path is relative, so it is resolved against the current working
# directory of the process.
svm_model = joblib.load("svm_phone_view_model.joblib")

# --- Label Mapping ---
# Maps an integer class index to the human-readable phone view name.
label_map = dict(enumerate(("back", "bottom", "front", "top")))
19
+
20
+ # --- Function to Extract CLIP Embedding ---
21
def extract_clip_embedding(image):
    """Return the CLIP image embedding of *image* as a NumPy array.

    The image is preprocessed by ``clip_processor`` into PyTorch tensors,
    passed through ``clip_model.get_image_features`` with gradient tracking
    disabled, and the result is squeezed and converted to NumPy.

    Args:
        image: The image to embed, in a form accepted by ``clip_processor``
            (the caller in this file supplies a PIL image).

    Returns:
        The squeezed feature tensor as a ``numpy.ndarray`` (presumably a
        1-D vector for a single image -- confirm against the model).
    """
    model_inputs = clip_processor(images=image, return_tensors="pt")

    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        image_features = clip_model.get_image_features(**model_inputs)

    squeezed = image_features.squeeze()
    return squeezed.numpy()
26
+
27
+ # --- Gradio prediction function ---
28
def predict_image_view(image):
    """Classify the phone view shown in *image* and format the result.

    The image is embedded with ``extract_clip_embedding``, scored with
    ``svm_model.predict_proba``, and the most probable class is reported
    together with its probability as a percentage.

    Args:
        image: The image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A string of the form ``"View: FRONT (97.31%)"``, or a short prompt
        asking for an image when *image* is ``None``.
    """
    # Guard against an empty submission instead of failing inside the
    # CLIP processor with an opaque error.
    if image is None:
        return "Please upload an image."

    embedding = extract_clip_embedding(image)
    probs = svm_model.predict_proba([embedding])[0]
    pred_index = int(np.argmax(probs))

    # predict_proba columns are ordered like ``svm_model.classes_``, so map
    # the column position to the actual class label before looking up its
    # display name; fall back to the raw index if ``classes_`` is missing.
    classes = getattr(svm_model, "classes_", None)
    class_id = classes[pred_index] if classes is not None else pred_index

    # If the label is not a key of label_map (e.g. the model was trained on
    # string labels), show the label itself rather than raising KeyError.
    prediction = str(label_map.get(class_id, class_id))
    confidence = probs[pred_index] * 100
    return f"View: {prediction.upper()} ({confidence:.2f}%)"
35
+
36
+ # --- Launch Gradio interface ---
37
# Single image input, delivered to the prediction function as a PIL image.
_image_input = gr.Image(type="pil")

# Interface wiring: one image in, one line of text out.
demo = gr.Interface(
    predict_image_view,
    _image_input,
    "text",
    title="Phone View Classifier (4-class)",
    description="Upload an image of a phone and classify it as one of: Front, Back, Top, Bottom",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ scikit-learn>=1.3.0
4
+ joblib>=1.3.2
5
+ Pillow>=9.5.0
6
+ numpy>=1.24.0
7
+ tqdm>=4.65.0
8
+ gradio