import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces  # Hugging Face Spaces helper (ZeroGPU); imported here but no GPU decorator is applied
from torchvision import transforms
import easyocr
from transformers import CLIPProcessor, CLIPModel
# Token for downloading the private checkpoint from the Hugging Face Hub
HF_TOKEN = os.environ.get("HF_TOKEN")
# Class labels in training order (indices must match the saved checkpoint).
# '.ipynb_checkpoints' looks like a stray Jupyter folder picked up during
# training; it is kept so the label indices stay aligned with the weights.
classes = ['.ipynb_checkpoints', '2-products-in-one-offer',
           '2-products-in-one-offer-+-coupon',
           'an-offer-with-a-coupon',
           'availability-of-additional-products',
           'offers-with-a-preliminary-promotional-price',
           'offers-with-an-additional-deposit-price',
           'offers-with-an-additional-shipping',
           'offers-with-dealtype-special_price',
           'offers-with-different-sizes',
           'offers-with-money_rebate',
           'offers-with-percentage_rebate',
           'offers-with-price-characteristic-(statt)',
           'offers-with-price-characterization-(uvp)',
           'offers-with-product-number-(sku)',
           'offers-with-reward',
           'offers-with-the-condition_-available-from-such-and-such-a-number',
           'offers-with-the-old-price-crossed-out',
           'regular',
           'scene-with-multiple-offers-+-uvp-price-for-each-offers',
           'several-products-in-one-offer-with-different-prices',
           'simple-offers',
           'stock-offers',
           'stocks',
           'travel-booklets',
           'with-a-product-without-a-price',
           'with-the-price-of-the-supplemental-bid']
# Custom CLIP-based multimodal classifier: a linear head on top of the
# averaged CLIP image and text embeddings
class CLIPMultimodalClassifier(torch.nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.fc = torch.nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images, texts):
        image_features = self.clip_model.get_image_features(images)
        text_features = self.clip_model.get_text_features(texts)
        # Fuse the two modalities by simple averaging, then classify
        combined_features = (image_features + text_features) / 2
        logits = self.fc(combined_features)
        return logits
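# Optional shape sanity check for the fusion head (a sketch, not part of the
# app flow: it builds the classifier with a fresh 3-way head and pushes random
# tensors through it just to confirm the dimensions line up; uncomment to run):
# _clf = CLIPMultimodalClassifier(num_classes=3)
# _imgs = torch.randn(2, 3, 224, 224)      # batch of 2 RGB images
# _ids = torch.randint(0, 1000, (2, 77))   # batch of 2 tokenized texts
# with torch.no_grad():
#     print(_clf(_imgs, _ids).shape)       # expected: torch.Size([2, 3])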
# Image preprocessing (resize and normalize). The single-value (0.5, 0.5)
# statistics broadcast across all three RGB channels; they must match whatever
# the checkpoint was fine-tuned with, so they are left as-is.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
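# For reference, CLIP's own preprocessing uses per-channel statistics (shown
# below as a sketch; switching to them is only correct if the checkpoint was
# trained with the same normalization):
# clip_normalize = transforms.Normalize(
#     mean=(0.48145466, 0.4578275, 0.40821073),
#     std=(0.26862954, 0.26130258, 0.27577711),
# )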
# EasyOCR reader for text extraction (supports English and German)
ocr_reader = easyocr.Reader(['en', 'de'], model_storage_directory="./")

# Build the classifier and load the fine-tuned checkpoint once at startup,
# rather than re-downloading and re-loading the weights on every request
num_classes = len(classes)
model = CLIPMultimodalClassifier(num_classes)
model_path = hf_hub_download(
    repo_id="limitedonly41/offers_26",
    filename="continued_training_model_5.pth",
    token=HF_TOKEN,  # `token` replaces the deprecated `use_auth_token` argument
)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Inference function
def run_inference(image, model, clip_processor):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Move the already-loaded model to the available device
    model = model.to(device)
    # Convert the incoming numpy array to a PIL image and preprocess it
    image_pil = Image.fromarray(image).convert("RGB")
    image_tensor = image_transform(image_pil).unsqueeze(0).to(device)

    # Extract text from the image using EasyOCR
    ocr_text = ocr_reader.readtext(image, detail=0)
    combined_text = " ".join(ocr_text)  # Join OCR results into one string

    # Tokenize the OCR text for CLIP (77 tokens is CLIP's context length)
    text_inputs = clip_processor(
        text=[combined_text],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77,
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor, text_inputs["input_ids"])
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()

    # Return results
    predicted_class = classes[predicted_class_idx]
    return f"Predicted Class: {predicted_class}\nExtracted Text: {combined_text}"
# Create a Gradio interface
iface = gr.Interface(
    fn=lambda image: run_inference(image, model, clip_processor),
    inputs=gr.Image(type="numpy"),  # Image arrives as a numpy array
    outputs="text",  # Predicted class plus the extracted OCR text
    title="Image Classification",
    description="Upload an offer image to get the predicted class from the CLIP-based multimodal classifier.",
)
# Launch the Gradio app
iface.launch()
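# When running locally rather than on Spaces, a temporary public link can be
# requested instead (a standard Gradio launch option, shown here as a sketch):
# iface.launch(share=True)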