File size: 5,814 Bytes
4dd1fe2
 
82663dd
4dd1fe2
 
82663dd
 
 
 
 
 
4dd1fe2
 
 
82663dd
 
751b8f9
 
4dd1fe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01e138f
 
 
 
 
 
 
 
 
 
 
 
 
 
82663dd
 
 
4dd1fe2
 
 
 
 
 
82663dd
28c504c
7b0cc34
0feb61f
4dd1fe2
751b8f9
94fde9b
 
 
 
bf89789
2cb41dd
 
4dd1fe2
82663dd
4dd1fe2
c94cb62
eb93cf6
 
751b8f9
 
 
0feb61f
bf89789
 
751b8f9
bf89789
751b8f9
 
bf89789
 
 
9f307f3
bf89789
 
 
 
0feb61f
bf89789
 
 
 
 
861df2d
 
 
 
bf89789
0f3b6f1
bf89789
861df2d
bf89789
 
94fde9b
 
bf89789
 
 
 
 
4dd1fe2
edefda0
28887b8
 
4dd1fe2
 
 
28887b8
4dd1fe2
 
 
28887b8
4dd1fe2
 
 
 
 
 
28887b8
4dd1fe2
 
 
 
 
28887b8
4dd1fe2
 
 
 
82663dd
 
bf89789
82663dd
 
 
 
4dd1fe2
 
82663dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from torchvision import transforms


import easyocr
from transformers import CLIPProcessor, CLIPModel
from huggingface_hub import hf_hub_download

HF_TOKEN = os.environ.get("HF_TOKEN")
model = None
clip_processor = None
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Offer-layout class labels. Order matters: the index of each entry must match
# the corresponding output logit of the trained linear head (see
# CLIPMultimodalClassifier), so do not reorder or remove entries without
# retraining / re-exporting the checkpoint.
# NOTE(review): '.ipynb_checkpoints' looks like a training-directory artifact
# that leaked into the label set — confirm against the checkpoint's label
# order before removing it (the head's output size is len(classes)).
classes = ['.ipynb_checkpoints', '2-products-in-one-offer',
           '2-products-in-one-offer-+-coupon',
           'an-offer-with-a-coupon',
           'availability-of-additional-products',
           'offers-with-a-preliminary-promotional-price',
           'offers-with-an-additional-deposit-price',
           'offers-with-an-additional-shipping',
           'offers-with-dealtype-special_price',
           'offers-with-different-sizes',
           'offers-with-money_rebate',
           'offers-with-percentage_rebate',
           'offers-with-price-characteristic-(statt)',
           'offers-with-price-characterization-(uvp)',
           'offers-with-product-number-(sku)',
           'offers-with-reward',
           'offers-with-the-condition_-available-from-such-and-such-a-number',
           'offers-with-the-old-price-crossed-out',
           'regular',
           'scene-with-multiple-offers-+-uvp-price-for-each-offers',
           'several-products-in-one-offer-with-different-prices',
           'simple-offers',
           'stock-offers',
           'stocks',
           'travel-booklets',
           'with-a-product-without-a-price',
           'with-the-price-of-the-supplemental-bid']

# Custom CLIP-based Multimodal Classifier
class CLIPMultimodalClassifier(torch.nn.Module):
    """Multimodal classifier: CLIP image + text encoders fused by averaging.

    The pretrained CLIP backbone projects both modalities into the same
    embedding space (``projection_dim``); the two embeddings are averaged
    and passed through a single linear head over ``num_classes`` labels.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.fc = torch.nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images, texts):
        # Encode each modality with CLIP, then fuse with a simple mean.
        img_emb = self.clip_model.get_image_features(images)
        txt_emb = self.clip_model.get_text_features(texts)
        fused = (img_emb + txt_emb) / 2
        return self.fc(fused)



# Image preprocessing: resize to the CLIP ViT input size (224x224), convert
# to a tensor in [0, 1], then scale to roughly [-1, 1].
# NOTE(review): mean/std of 0.5 differ from CLIP's published normalization
# stats — presumably matching how the checkpoint was trained; verify.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# EasyOCR reader for text extraction (English + German). Instantiation may
# download detection/recognition models into the current working directory
# (model_storage_directory="./").
# ocr_reader = easyocr.Reader(['en', 'de'], gpu=False)  # CPU-only variant kept for reference
ocr_reader = easyocr.Reader(['en', 'de'], model_storage_directory="./")



# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Build the classifier architecture and text processor at import time.
# The fine-tuned weights themselves are downloaded and loaded inside
# run_inference (one output logit per entry in `classes`).
num_classes = len(classes)
model = CLIPMultimodalClassifier(num_classes)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")



# Inference function
@spaces.GPU()
def run_inference(image, model, clip_processor):
    """Classify an offer image with the CLIP multimodal model.

    Args:
        image: numpy array (H, W, C) as delivered by the Gradio image input.
        model: CLIPMultimodalClassifier instance; its fine-tuned weights are
            downloaded from the Hub and loaded lazily on the first call.
        clip_processor: CLIPProcessor used to tokenize the OCR-extracted text.

    Returns:
        str: "Predicted Class: <label>\nExtracted Text: <ocr text>".
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Load the fine-tuned checkpoint only once per process. hf_hub_download
    # caches the file on disk, but re-reading and re-applying the state dict
    # on every request is pure overhead.
    if not getattr(run_inference, "_weights_loaded", False):
        model_path = hf_hub_download(repo_id="limitedonly41/offers_26",
                                     filename="continued_training_model_5.pth",
                                     use_auth_token=HF_TOKEN)
        model.load_state_dict(torch.load(model_path, map_location=device))
        run_inference._weights_loaded = True
    model.eval()

    image_pil = Image.fromarray(image).convert("RGB")

    # Preprocess image: resize/normalize and add a batch dimension.
    image_tensor = image_transform(image_pil).unsqueeze(0).to(device)

    # Extract text using EasyOCR (detail=0 returns plain strings only).
    ocr_text = ocr_reader.readtext(image, detail=0)
    combined_text = " ".join(ocr_text)  # Join OCR results into one string

    # Tokenize the OCR text for CLIP (77 is CLIP's text context length).
    text_inputs = clip_processor(
        text=[combined_text],  # Text in a list
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor, text_inputs["input_ids"])
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()

    # Map the winning logit index back to its human-readable label.
    predicted_class = classes[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nExtracted Text: {combined_text}"


# Gradio UI: single image in, predicted class + OCR text out.
def _classify(image):
    # Thin adapter binding the module-level model/processor to the UI callback.
    return run_inference(image, model, clip_processor)


iface = gr.Interface(
    fn=_classify,
    inputs=gr.Image(type="numpy"),  # numpy array handed straight to run_inference
    outputs="text",
    title="Image Classification",
    description="Upload an image to get the predicted class using the ViT model.",
)

# Launch the Gradio app
iface.launch()