Spaces:
Sleeping
Sleeping
File size: 1,456 Bytes
9419ab1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# app.py
import os
import torch
import numpy as np
from PIL import Image
import joblib
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
# --- Load CLIP Model and Processor ---
# Downloaded from the Hugging Face hub at import time; used only for feature
# extraction (get_image_features), never fine-tuned here.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# --- Load Trained SVM Model ---
# Pre-trained scikit-learn classifier over CLIP embeddings; the .joblib file
# must sit next to this script. NOTE(review): presumably trained with
# probability=True, since predict_proba is called below — confirm.
svm_model = joblib.load("svm_phone_view_model.joblib")
# --- Label Mapping ---
# Maps the SVM's integer class ids to human-readable phone-view names.
label_map = {0: "back", 1: "bottom", 2: "front", 3: "top"}
# --- Function to Extract CLIP Embedding ---
def extract_clip_embedding(image):
    """Return the CLIP image embedding for *image* as a 1-D numpy array.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image (Gradio hands this function a PIL image).

    Returns
    -------
    numpy.ndarray
        CLIP vision features with the batch dimension squeezed out
        (512-d for clip-vit-base-patch32 — TODO confirm against checkpoint).
    """
    inputs = clip_processor(images=image, return_tensors="pt")
    # Inference only: autograd bookkeeping is wasted work here.
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    return features.squeeze().numpy()
# --- Gradio prediction function ---
def predict_image_view(image):
    """Classify which side of a phone an image shows.

    Parameters
    ----------
    image : PIL.Image.Image
        Uploaded photo of a phone.

    Returns
    -------
    str
        Formatted string, e.g. ``"View: FRONT (97.31%)"``.
    """
    embedding = extract_clip_embedding(image)
    probs = svm_model.predict_proba([embedding])[0]
    # predict_proba columns follow svm_model.classes_ order, which is not
    # guaranteed to be 0..3 — map through classes_ rather than assuming
    # column index == label id. Identical result when classes_ is [0,1,2,3].
    best = int(np.argmax(probs))
    prediction = label_map[int(svm_model.classes_[best])]
    confidence = probs[best] * 100
    return f"View: {prediction.upper()} ({confidence:.2f}%)"
# --- Launch Gradio interface ---
# Single-image-in, text-out app; `type="pil"` makes Gradio pass a PIL image,
# which is what predict_image_view expects.
demo = gr.Interface(
    fn=predict_image_view,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Phone View Classifier (4-class)",
    description="Upload an image of a phone and classify it as one of: Front, Back, Top, Bottom",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|