import os
import zipfile

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import gradio as gr
# Step 1: Unzip the dataset
if not os.path.exists("data"):
    os.makedirs("data")
    print("Extracting Data.zip...")
    with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
        zip_ref.extractall("data")
    print("Extraction complete.")
# Step 2: Dynamically find the 'safe' and 'unsafe' folders
def find_dataset_path(root_dir):
    for root, dirs, files in os.walk(root_dir):
        if 'safe' in dirs and 'unsafe' in dirs:
            return root
    return None

# Look for 'safe' and 'unsafe' inside 'data/Data'
dataset_path = find_dataset_path("data/Data")
if dataset_path is None:
    print("Debugging extracted structure:")
    for root, dirs, files in os.walk("data"):
        print(f"Root: {root}")
        print(f"Directories: {dirs}")
        print(f"Files: {files}")
    raise FileNotFoundError("Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. Please check the Data.zip structure.")
print(f"Dataset path found: {dataset_path}")
# Step 3: Define Custom Dataset Class
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
            folder_path = os.path.join(root_dir, folder)
            if not os.path.exists(folder_path):
                raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
            for filename in os.listdir(folder_path):
                # Only load image files (case-insensitive extension check)
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(folder_path, filename))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Step 4: Data Transformations
# Normalize with the CLIP pretraining statistics so the fine-tuning inputs
# match what the vision encoder (and CLIPProcessor at inference) expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 pixels
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),  # CLIP image mean
        std=(0.26862954, 0.26130258, 0.27577711),  # CLIP image std
    ),
])
# Step 5: Load the Dataset
train_dataset = CustomImageDataset(dataset_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Debugging: Check the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
if len(train_dataset) == 0:
    raise ValueError("The dataset is empty. Please check if 'Data.zip' is correctly unzipped and contains 'safe' and 'unsafe' folders.")
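# Optional sanity check (a quick sketch; assumes the dataset yields at least
# one batch): inspect one batch before training. Image tensors should be
# [batch, 3, 224, 224] and labels a LongTensor of 0s and 1s.
sample_images, sample_labels = next(iter(train_loader))
print(f"Sample batch: images {tuple(sample_images.shape)}, labels {sample_labels.tolist()}")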
# Step 6: Load Pretrained CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Add a Classification Layer on top of the image features
model.classifier = nn.Linear(model.visual_projection.out_features, 2)  # 2 classes: safe, unsafe

# Define Optimizer and Loss Function (only the new head is optimized)
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Step 7: Fine-Tune the Model. The CLIP backbone stays frozen; only the
# classifier head receives gradient updates, matching the optimizer above.
model.train()
for epoch in range(3):  # Number of epochs
    total_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        # The DataLoader already returns a batched float tensor, so no
        # re-stacking is needed; skip gradients for the frozen backbone.
        with torch.no_grad():
            features = model.get_image_features(pixel_values=images)
        logits = model.classifier(features)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")
# Save the Fine-Tuned Model
model.save_pretrained("fine-tuned-model")
processor.save_pretrained("fine-tuned-model")
print("Model fine-tuned and saved successfully.")
# Step 8: Define Gradio Inference Function
def classify_image(image, class_names):
    # Load the fine-tuned model (reloaded on every call for simplicity).
    # Note: this performs zero-shot text-image matching against the entered
    # class names; it does not use the classifier head trained in Step 7.
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")
    model.eval()

    # Split class names from comma-separated input
    labels = [label.strip() for label in class_names.split(",") if label.strip()]
    if not labels:
        return {"Error": "Please enter at least one valid class name."}

    # Process the image and labels
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities

    # Map labels to their probabilities, sorted from most to least likely
    result = {label: probs[0][i].item() for i, label in enumerate(labels)}
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
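# Alternative sketch (not wired into the UI below): classify with the
# fine-tuned head from Step 7 instead of zero-shot text matching. Assumes
# "classifier-head.pt" was saved as above; the label order follows the
# dataset definition (0 = safe, 1 = unsafe).
def classify_with_head(image):
    model = CLIPModel.from_pretrained("fine-tuned-model")
    processor = CLIPProcessor.from_pretrained("fine-tuned-model")
    head = nn.Linear(model.visual_projection.out_features, 2)
    head.load_state_dict(torch.load("fine-tuned-model/classifier-head.pt"))
    model.eval()
    head.eval()
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(pixel_values=inputs["pixel_values"])
        probs = head(features).softmax(dim=1)[0]
    return {"safe": probs[0].item(), "unsafe": probs[1].item()}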
# Step 9: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe"),
    ],
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
)
# Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()