"""Classify a ZIP of images with ResNet50, analyze dominant colors and faces
(DeepFace), and serve the results through a Gradio interface."""

import io
import os
import tempfile
import zipfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from deepface import DeepFace
from PIL import Image
from sklearn.cluster import MiniBatchKMeans
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

# ---------------------------
# Force CPU if no CUDA
# ---------------------------
if not torch.cuda.is_available():
    # Hide GPUs entirely so downstream libraries (e.g. DeepFace/TF) also stay on CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Load ResNet50 (pretrained, inference only)
# ---------------------------
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()

# ---------------------------
# Preprocessing matching the ImageNet training recipe
# ---------------------------
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ---------------------------
# ImageNet labels — fetched once at startup.
# A timeout prevents the whole app from hanging on a dead network.
# ---------------------------
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
imagenet_classes = [
    line.strip()
    for line in requests.get(LABELS_URL, timeout=30).text.splitlines()
]

# ---------------------------
# Basic color utilities
# ---------------------------
BASIC_COLORS = {
    "Red": (255, 0, 0),
    "Green": (0, 255, 0),
    "Blue": (0, 0, 255),
    "Yellow": (255, 255, 0),
    "Cyan": (0, 255, 255),
    "Magenta": (255, 0, 255),
    "Black": (0, 0, 0),
    "White": (255, 255, 255),
    "Gray": (128, 128, 128),
}

# Bar colors for the weighted-gender plot; unknown labels fall back to gray
# (the aggregation below can add keys beyond Man/Woman).
_GENDER_BAR_COLORS = {"Man": "lightblue", "Woman": "pink"}


def closest_basic_color(rgb):
    """Return the BASIC_COLORS name nearest to *rgb* by squared RGB distance."""
    r, g, b = rgb
    best_name, best_dist = None, float("inf")
    for name, (cr, cg, cb) in BASIC_COLORS.items():
        dist = (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2
        if dist < best_dist:
            best_dist, best_name = dist, name
    return best_name


def get_dominant_color(image, num_colors=5):
    """Cluster pixel colors and return ``(rgb_tuple, hex_string)`` of the
    most-populated cluster.

    *image* must be an RGB PIL image; it is downscaled to 300x300 before
    clustering to keep MiniBatchKMeans fast.
    """
    pixels = np.array(image.resize((300, 300))).reshape(-1, 3)
    kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
    kmeans.fit(pixels)
    # Center of the largest cluster = dominant color.
    center = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
    rgb = tuple(int(c) for c in center)  # plain ints, not numpy scalars
    hex_color = "#{:02x}{:02x}{:02x}".format(*rgb)
    return rgb, hex_color


def _fig_to_image(fig, **savefig_kwargs):
    """Render *fig* to a PIL image and close it.

    Uses the figure-bound ``fig.savefig`` (not module-level ``plt.savefig``,
    which silently targets whatever the *current* figure happens to be).
    """
    buf = io.BytesIO()
    fig.savefig(buf, format="png", **savefig_kwargs)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _normalize_gender(gender):
    """Normalize DeepFace's gender field to a ``{label: confidence%}`` dict.

    Newer DeepFace versions return a confidence dict; older ones return a
    plain string, which would otherwise be iterated character-by-character
    by the downstream ``max(gender_dict, key=gender_dict.get)`` logic.
    """
    if isinstance(gender, dict):
        return {str(k): float(v) for k, v in gender.items()}
    return {str(gender): 100.0}


def _analyze_faces(image):
    """Run DeepFace age/gender/emotion analysis on a PIL RGB image.

    Returns a list of ``{"age", "gender", "emotion"}`` dicts (gender is
    always a confidence dict — see :func:`_normalize_gender`). Face analysis
    is best-effort: any failure yields an empty list rather than aborting
    the batch.
    """
    try:
        img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        detected = DeepFace.analyze(
            img_bgr,
            actions=["age", "gender", "emotion"],
            enforce_detection=False,
        )
        if not isinstance(detected, list):
            detected = [detected]
        return [
            {
                "age": f["age"],
                "gender": _normalize_gender(f["gender"]),
                "emotion": f["dominant_emotion"],
            }
            for f in detected
        ]
    except Exception:
        return []


def _iter_image_files(root):
    """Yield absolute paths of PNG/JPEG files under *root*, sorted, recursing
    into subdirectories (ZIPs often nest images in folders). Skips macOS
    archive artifacts (``__MACOSX``, dot-files)."""
    paths = []
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames[:] = [d for d in dirnames if d != "__MACOSX"]
        for fname in filenames:
            if fname.startswith("."):
                continue
            if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                paths.append(os.path.join(dirpath, fname))
    return sorted(paths)


def classify_zip_and_analyze_color(zip_file):
    """Process every image in an uploaded ZIP.

    For each image: top-3 ImageNet predictions, dominant color (hex + nearest
    basic color), and per-face age/gender/emotion. Returns
    ``(dataframe, xlsx_path, thumbnails, plot1, plot2, plot3, plot4)``
    matching the Gradio output components below.
    """
    # Newer Gradio passes a filepath string; older versions pass a file object.
    zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
    zip_basename = os.path.splitext(os.path.basename(zip_path))[0]

    results = []
    thumbnails = []

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(tmpdir)

        for img_path in _iter_image_files(tmpdir):
            fname = os.path.basename(img_path)
            try:
                image = Image.open(img_path).convert("RGB")
            except Exception:
                # Corrupt / unreadable file: skip rather than abort the batch.
                continue

            # Thumbnail for the gallery (higher-quality resampling).
            thumbnails.append(image.copy().resize((200, 200), Image.LANCZOS))

            # Classification: top-3 ImageNet classes with confidences.
            input_tensor = transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(input_tensor)
            probs = F.softmax(output, dim=1)[0]
            top3_prob, top3_idx = torch.topk(probs, 3)
            preds = [
                (imagenet_classes[idx], f"{prob.item() * 100:.2f}%")
                for idx, prob in zip(top3_idx, top3_prob)
            ]

            # Dominant color
            rgb, hex_color = get_dominant_color(image)
            basic_color = closest_basic_color(rgb)

            # Face detection & characterization (best-effort)
            faces_data = _analyze_faces(image)

            results.append((
                fname,
                ", ".join(p[0] for p in preds),
                ", ".join(p[1] for p in preds),
                hex_color,
                basic_color,
                faces_data,
            ))

    df = pd.DataFrame(
        results,
        columns=["Filename", "Top 3 Predictions", "Confidence",
                 "Dominant Color", "Basic Color", "Face Info"],
    )

    # XLSX named after the uploaded ZIP.
    out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_basename}_results.xlsx")
    df.to_excel(out_xlsx, index=False)

    # ---------------------------
    # Plot 1: basic color frequency
    # ---------------------------
    fig1, ax1 = plt.subplots()
    color_counts = df["Basic Color"].value_counts()
    ax1.bar(color_counts.index, color_counts.values, color="skyblue")
    ax1.set_title("Basic Color Frequency")
    ax1.set_ylabel("Count")
    plot1_img = _fig_to_image(fig1)

    # ---------------------------
    # Plot 2: top prediction distribution (20 most frequent labels)
    # ---------------------------
    fig2, ax2 = plt.subplots()
    preds_flat = []
    for p in df["Top 3 Predictions"]:
        preds_flat.extend(p.split(", "))
    pred_counts = pd.Series(preds_flat).value_counts().head(20)
    ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
    ax2.set_title("Top Prediction Distribution")
    ax2.set_xlabel("Count")
    plot2_img = _fig_to_image(fig2, bbox_inches="tight")

    # ---------------------------
    # Plot 3: gender distribution weighted by detection confidence,
    # with each face's contribution capped at 0.9.
    # ---------------------------
    gender_confidence = {"Man": 0, "Woman": 0}
    for face_list in df["Face Info"]:
        for face in face_list:
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            weight = min(float(gender_dict[gender]) / 100, 0.9)
            gender_confidence[gender] = gender_confidence.get(gender, 0) + weight

    fig3, ax3 = plt.subplots()
    # Map bar colors per label: the dict can gain keys beyond Man/Woman, and
    # matplotlib rejects a color list shorter than the number of bars.
    bar_colors = [_GENDER_BAR_COLORS.get(k, "gray") for k in gender_confidence]
    ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=bar_colors)
    ax3.set_title("Gender Distribution (Weighted ≤90%)")
    ax3.set_ylabel("Sum of Confidence")
    plot3_img = _fig_to_image(fig3)

    # ---------------------------
    # Plot 4: age distribution split by (most likely) gender
    # ---------------------------
    ages_men, ages_women = [], []
    for face_list in df["Face Info"]:
        for face in face_list:
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            if gender.lower() == "man":
                ages_men.append(face["age"])
            else:
                ages_women.append(face["age"])

    fig4, ax4 = plt.subplots()
    ax4.hist([ages_men, ages_women], bins=range(0, 101, 5),
             color=["lightblue", "pink"], label=["Men", "Women"], stacked=False)
    ax4.set_title("Age Distribution by Gender")
    ax4.set_xlabel("Age")
    ax4.set_ylabel("Count")
    ax4.legend()
    plot4_img = _fig_to_image(fig4)

    return df, out_xlsx, thumbnails, plot1_img, plot2_img, plot3_img, plot4_img


# ---------------------------
# Gradio Interface
# ---------------------------
demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    outputs=[
        gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence",
                              "Dominant Color", "Basic Color", "Face Info"]),
        gr.File(label="Download XLSX"),
        gr.Gallery(label="Thumbnails", show_label=True,
                   elem_id="thumbnail-gallery", columns=5),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
        gr.Image(type="pil", label="Age Distribution by Gender"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description=(
        "Upload a ZIP of images. Classifies images, analyzes dominant color, "
        "detects faces, and displays thumbnails."
    ),
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)