Spaces:
Sleeping
Sleeping
File size: 5,814 Bytes
4dd1fe2 82663dd 4dd1fe2 82663dd 4dd1fe2 82663dd 751b8f9 4dd1fe2 01e138f 82663dd 4dd1fe2 82663dd 28c504c 7b0cc34 0feb61f 4dd1fe2 751b8f9 94fde9b bf89789 2cb41dd 4dd1fe2 82663dd 4dd1fe2 c94cb62 eb93cf6 751b8f9 0feb61f bf89789 751b8f9 bf89789 751b8f9 bf89789 9f307f3 bf89789 0feb61f bf89789 861df2d bf89789 0f3b6f1 bf89789 861df2d bf89789 94fde9b bf89789 4dd1fe2 edefda0 28887b8 4dd1fe2 28887b8 4dd1fe2 28887b8 4dd1fe2 28887b8 4dd1fe2 28887b8 4dd1fe2 82663dd bf89789 82663dd 4dd1fe2 82663dd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from torchvision import transforms
import easyocr
from transformers import CLIPProcessor, CLIPModel
from huggingface_hub import hf_hub_download
# Hugging Face access token (needed to download the private checkpoint).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Globals bound later during module-level setup; also passed into run_inference.
model = None
clip_processor = None
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Class labels in the exact order the checkpoint was trained with.
# NOTE(review): '.ipynb_checkpoints' is a Jupyter artifact directory that leaked
# into the training labels; it cannot be dropped here without changing
# num_classes and breaking load_state_dict against the saved checkpoint.
classes = ['.ipynb_checkpoints', '2-products-in-one-offer',
'2-products-in-one-offer-+-coupon',
'an-offer-with-a-coupon',
'availability-of-additional-products',
'offers-with-a-preliminary-promotional-price',
'offers-with-an-additional-deposit-price',
'offers-with-an-additional-shipping',
'offers-with-dealtype-special_price',
'offers-with-different-sizes',
'offers-with-money_rebate',
'offers-with-percentage_rebate',
'offers-with-price-characteristic-(statt)',
'offers-with-price-characterization-(uvp)',
'offers-with-product-number-(sku)',
'offers-with-reward',
'offers-with-the-condition_-available-from-such-and-such-a-number',
'offers-with-the-old-price-crossed-out',
'regular',
'scene-with-multiple-offers-+-uvp-price-for-each-offers',
'several-products-in-one-offer-with-different-prices',
'simple-offers',
'stock-offers',
'stocks',
'travel-booklets',
'with-a-product-without-a-price',
'with-the-price-of-the-supplemental-bid']
# Custom CLIP-based Multimodal Classifier
class CLIPMultimodalClassifier(torch.nn.Module):
    """CLIP backbone with a linear head over the fused image/text embeddings.

    Both modalities are encoded with the same pretrained CLIP model, averaged
    element-wise, and projected to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Shared CLIP encoder; image and text project into the same space.
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # Classification head on top of the fused embedding.
        self.fc = torch.nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images, texts):
        """Return raw logits for a batch of images and matching token ids."""
        img_emb = self.clip_model.get_image_features(images)
        txt_emb = self.clip_model.get_text_features(texts)
        # Simple late fusion: element-wise mean of the two modality embeddings.
        fused = (img_emb + txt_emb) / 2
        return self.fc(fused)
# Image preprocessing (resize and normalize)
# NOTE(review): (0.5,) mean/std broadcasts over all 3 RGB channels; this is not
# CLIP's canonical normalization, but it must match what the checkpoint was
# trained with — confirm before changing.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# EasyOCR Reader for text extraction
# ocr_reader = easyocr.Reader(['en', 'de'], gpu=False) # Supports English and German
# Stores OCR detection/recognition models next to the app to survive restarts.
ocr_reader = easyocr.Reader(['en', 'de'], model_storage_directory="./")
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# Instantiate the classifier once at import time; the fine-tuned weights are
# loaded later inside run_inference.
num_classes = len(classes)
model = CLIPMultimodalClassifier(num_classes)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Inference function
@spaces.GPU()
def run_inference(image, model, clip_processor):
    """Classify an offer image using the multimodal CLIP model.

    Parameters
    ----------
    image : numpy.ndarray
        Image array as delivered by the Gradio ``numpy`` image component.
    model : CLIPMultimodalClassifier
        Classifier instance; its fine-tuned weights are loaded lazily from
        the Hub on the first call.
    clip_processor : CLIPProcessor
        Processor used to tokenize the OCR text for CLIP.

    Returns
    -------
    str
        Human-readable string with the predicted class and the extracted text.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # Load the fine-tuned weights once per process instead of on every request:
    # hf_hub_download caches the file on disk, but torch.load + load_state_dict
    # on each call was pure per-request overhead. (If the GPU decorator runs
    # each call in a fresh process, this simply reloads — still correct.)
    if not getattr(run_inference, "_weights_loaded", False):
        model_path = hf_hub_download(repo_id="limitedonly41/offers_26",
                                     filename="continued_training_model_5.pth",
                                     use_auth_token=HF_TOKEN)
        model.load_state_dict(torch.load(model_path, map_location=device))
        run_inference._weights_loaded = True
    model.eval()

    # Preprocess image: force RGB so grayscale/RGBA uploads don't break CLIP.
    image_pil = Image.fromarray(image).convert("RGB")
    image_tensor = image_transform(image_pil).unsqueeze(0).to(device)

    # Extract text using EasyOCR (detail=0 returns plain strings only).
    ocr_text = ocr_reader.readtext(image, detail=0)
    combined_text = " ".join(ocr_text)  # Join OCR results into one string

    # Tokenize the OCR text for CLIP; 77 is CLIP's fixed context length.
    text_inputs = clip_processor(
        text=[combined_text],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor, text_inputs["input_ids"])
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()

    predicted_class = classes[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nExtracted Text: {combined_text}"
# Create a Gradio interface: single image in, predicted class text out.
def _classify(image):
    # Thin adapter so the interface closes over the module-level model/processor.
    return run_inference(image, model, clip_processor)

iface = gr.Interface(
    fn=_classify,
    inputs=gr.Image(type="numpy"),  # Updated to use gr.Image
    outputs="text",  # Output is text (predicted class)
    title="Image Classification",
    description="Upload an image to get the predicted class using the ViT model."
)

# Launch the Gradio app
iface.launch()