Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| import os | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline | |
| import torch | |
| import gc | |
| # Set device | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # Load the main classifier (Main_Classifier_best_model.pth) | |
| main_model = models.resnet18(weights=None) # Updated: weights=None | |
| num_ftrs = main_model.fc.in_features | |
| main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones | |
| main_model.load_state_dict(torch.load('Main_Classifier_best_model.pth', map_location=device, weights_only=True)) # Updated: weights_only=True | |
| main_model = main_model.to(device) | |
| main_model.eval() | |
| # Define class names for the main classifier based on folder structure | |
| main_class_names = ['Clothing', 'Mobile Phones', 'Soda drinks'] | |
| # Sub-classifier models | |
| def load_soda_drinks_model(): | |
| model = models.resnet18(weights=None) # Updated: weights=None | |
| num_ftrs = model.fc.in_features | |
| model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up | |
| model.load_state_dict(torch.load('Soda_drinks_best_model.pth', map_location=device, weights_only=True)) # Updated | |
| model = model.to(device) | |
| model.eval() | |
| return model | |
| def load_clothing_model(): | |
| model = models.resnet18(weights=None) # Updated: weights=None | |
| num_ftrs = model.fc.in_features | |
| model.fc = nn.Linear(num_ftrs, 3) # 2 classes: Pants, T-Shirt | |
| model.load_state_dict(torch.load('Clothes_best_model.pth', map_location=device, weights_only=True)) # Updated | |
| model = model.to(device) | |
| model.eval() | |
| return model | |
| def load_mobile_phones_model(): | |
| model = models.resnet18(weights=None) # Updated: weights=None | |
| num_ftrs = model.fc.in_features | |
| model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung | |
| model.load_state_dict(torch.load('Phone_best_model.pth', map_location=device, weights_only=True)) # Updated | |
| model = model.to(device) | |
| model.eval() | |
| return model | |
| def convert_to_rgb(image): | |
| """ | |
| Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'. | |
| This is to avoid transparency issues during model training. | |
| """ | |
| if image.mode in ('P', 'RGBA'): | |
| return image.convert('RGB') | |
| return image | |
| # Define preprocessing transformations (same used during training) | |
| preprocess = transforms.Compose([ | |
| transforms.Lambda(convert_to_rgb), | |
| transforms.Resize((224, 224)), # Resize here, no need for shape argument in gr.Image | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization | |
| ]) | |
| # Load Meta's LLaMA model for generating product descriptions | |
| def load_llama(): | |
| model_name = "meta-llama/Llama-3.2-1B-Instruct" | |
| token = os.getenv("HUGGINGFACE_TOKEN") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token) | |
| model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=token).to(device) | |
| # Initialize the text generation pipeline with the prepared model | |
| text_generation = pipeline( | |
| "text-generation", | |
| model=model, | |
| tokenizer=tokenizer | |
| ) | |
| return tokenizer, model | |
| llama_tokenizer, llama_model = load_llama() | |
| # Generate product description using external data and structured format | |
| def generate_description(category, subclass): | |
| # Define file path and read content | |
| file_path = 'data for product description.txt' | |
| with open(file_path, 'r', encoding='utf-8') as file: | |
| file_content = file.read() | |
| prompt = f""" | |
| [Data] | |
| {file_content} | |
| Role: You are a product description content writer with 10 years of experience in the market. Generate a product description for a {subclass} in the {category} category based on the [Data] provided. | |
| Follow the [Instructions] strictly: | |
| [Instructions] | |
| - Create a detailed product description for a {subclass} in the {category} category based on the [Data]. | |
| - Use the structured format below, making each section clear and concise. | |
| - Highlight key product features, technical specifications, and the target audience. | |
| """ | |
| generated_texts = llama_model.generate( | |
| inputs=llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device), | |
| max_length=7000, | |
| max_new_tokens=2000, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_k=50, | |
| top_p=0.95, | |
| ) | |
| description = llama_tokenizer.decode(generated_texts[0], skip_special_tokens=True) | |
| # Clean up resources | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| return description | |
| # # Generate product description using LLaMA | |
| # def generate_description(category, subclass): | |
| # prompt = f"Generate a detailed and engaging product description for a {category} of type {subclass}." | |
| # inputs = llama_tokenizer.encode(prompt, return_tensors="pt").to(device) | |
| # outputs = llama_model.generate(inputs, max_length=100, do_sample=True, temperature=0.7, top_k=50, top_p=0.95) | |
| # description = llama_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # return description | |
| def classify_image(image): | |
| # Open the image using PIL | |
| image = Image.fromarray(image) | |
| # Preprocess the image | |
| input_image = preprocess(image).unsqueeze(0).to(device) | |
| # Perform inference with the main classifier | |
| with torch.no_grad(): | |
| output = main_model(input_image) | |
| probabilities = torch.nn.functional.softmax(output[0], dim=0) | |
| confidence, predicted_class = torch.max(probabilities, 0) | |
| # Main classifier result | |
| main_prediction = main_class_names[predicted_class] | |
| main_confidence = confidence.item() | |
| if main_confidence <=0.90: | |
| main_prediction = 'Others' | |
| main_confidence = 100-main_confidence | |
| sub_prediction = "Undefined" | |
| sub_confidence = -100 | |
| description = None | |
| # Load and apply the sub-classifier based on the main classification | |
| if main_prediction in ['Clothing', 'Mobile Phones', 'Soda drinks']: | |
| if main_prediction == 'Soda drinks': | |
| soda_model = load_soda_drinks_model() | |
| sub_class_names = ['Miranda', 'Pepsi', 'Seven Up'] | |
| with torch.no_grad(): | |
| sub_output = soda_model(input_image) | |
| elif main_prediction == 'Clothing': | |
| clothing_model = load_clothing_model() | |
| sub_class_names = ['Pants', 'T-Shirt','others'] | |
| with torch.no_grad(): | |
| sub_output = clothing_model(input_image) | |
| elif main_prediction == 'Mobile Phones': | |
| phones_model = load_mobile_phones_model() | |
| sub_class_names = ['Apple', 'Samsung'] | |
| with torch.no_grad(): | |
| sub_output = phones_model(input_image) | |
| # Perform inference with the sub-classifier | |
| sub_probabilities = torch.nn.functional.softmax(sub_output[0], dim=0) | |
| sub_confidence, sub_predicted_class = torch.max(sub_probabilities, 0) | |
| sub_prediction = sub_class_names[sub_predicted_class] | |
| sub_confidence = sub_confidence.item() | |
| if sub_confidence < 0.90 : | |
| sub_prediction = "Others" | |
| sub_confidence = 100- sub_confidence | |
| description=None | |
| else: | |
| # Generate product description | |
| description = generate_description(main_prediction, sub_prediction) | |
| return f"Main Predicted Class: {main_prediction} (Confidence: {main_confidence:.4f})", \ | |
| f"Sub Predicted Class: {sub_prediction} (Confidence: {sub_confidence:.4f})", \ | |
| f"Product Description: {description}" | |
| # Gradio interface (updated) | |
| image_input = gr.Image(image_mode="RGB") # Removed shape argument | |
| output_text = gr.Textbox() | |
| gr.Interface(fn=classify_image, inputs=image_input, outputs=[output_text], | |
| title="Main and Sub-Classifier System product description ", | |
| description="Upload an image to classify whether it belongs to Clothing, Mobile Phones, or Soda Drinks. Based on the prediction, it will further classify within the subcategory and generate a detailed product description .", | |
| theme="default").launch() |