# (HuggingFace Spaces status header — non-code page artifact, kept as a comment)
# Spaces: Sleeping
| import os | |
| import zipfile | |
| import tempfile | |
| import requests | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from torchvision.models import resnet50, ResNet50_Weights | |
| from sklearn.cluster import MiniBatchKMeans | |
| import matplotlib.pyplot as plt | |
| import io | |
| import gradio as gr | |
| # Face analysis | |
| from deepface import DeepFace | |
| import cv2 | |
# ---------------------------
# Force CPU if no CUDA
# ---------------------------
# Hide CUDA devices entirely when no GPU is present — presumably so other
# CUDA-aware libraries loaded later (e.g. DeepFace's backend) do not try to
# initialize a missing driver; TODO confirm this is still needed.
if not torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---------------------------
# Load ResNet50
# ---------------------------
# Pretrained ImageNet classifier, used for inference only: eval() freezes
# dropout / batch-norm behavior.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()
# ---------------------------
# Transformations
# ---------------------------
# Standard ImageNet preprocessing matching the ResNet50 weights: resize the
# shorter side to 256, center-crop 224x224, convert to tensor, and normalize
# with the ImageNet channel means/stds.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# ---------------------------
# ImageNet labels
# ---------------------------
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Fetch the 1000 ImageNet class names once at import time. The previous
# unchecked requests.get() had no timeout (could hang indefinitely) and no
# status check (an HTTP error page would silently become the class list).
_labels_resp = requests.get(LABELS_URL, timeout=30)
_labels_resp.raise_for_status()
imagenet_classes = [line.strip() for line in _labels_resp.text.splitlines()]
# ---------------------------
# Basic color utilities
# ---------------------------
# Reference palette: a small set of named anchor colors used to give each
# image's dominant color a human-readable name.
BASIC_COLORS = {
    "Red": (255, 0, 0),
    "Green": (0, 255, 0),
    "Blue": (0, 0, 255),
    "Yellow": (255, 255, 0),
    "Cyan": (0, 255, 255),
    "Magenta": (255, 0, 255),
    "Black": (0, 0, 0),
    "White": (255, 255, 255),
    "Gray": (128, 128, 128),
}


def closest_basic_color(rgb):
    """Return the BASIC_COLORS name nearest to *rgb*.

    Nearness is squared Euclidean distance in RGB space; on a tie, the
    first palette entry encountered wins.
    """
    red, green, blue = rgb

    def _sq_distance(item):
        pr, pg, pb = item[1]
        return (red - pr) ** 2 + (green - pg) ** 2 + (blue - pb) ** 2

    best_name, _ = min(BASIC_COLORS.items(), key=_sq_distance)
    return best_name
def get_dominant_color(image, num_colors=5):
    """Return the most common color in *image* as ((r, g, b), "#rrggbb").

    Fix: the original reshaped the pixel array to (-1, 3) without ensuring
    the image has exactly 3 channels, so RGBA or grayscale inputs crashed.
    The image is now converted to RGB first; 3-channel callers see
    identical behavior.

    Args:
        image: PIL.Image.Image to analyze (any mode).
        num_colors: number of k-means clusters to fit (default 5).

    Returns:
        Tuple of (rgb_tuple of ints, hex color string) for the center of
        the most populated cluster.
    """
    # Convert before reshaping (3 channels guaranteed); downscale to 300x300
    # to bound the k-means cost regardless of input resolution.
    image = image.convert("RGB").resize((300, 300))
    pixels = np.asarray(image).reshape(-1, 3)
    kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
    kmeans.fit(pixels)
    # Dominant color = center of the cluster that claimed the most pixels.
    dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
    dominant_color = tuple(int(c) for c in dominant_color)
    hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
    return dominant_color, hex_color
# ---------------------------
# Core function
# ---------------------------
def classify_zip_and_analyze_color(zip_file):
    """Classify every image in an uploaded ZIP and analyze colors and faces.

    Args:
        zip_file: uploaded file object whose ``.name`` is a filesystem path
            to the ZIP archive (Gradio ``gr.File`` upload — confirm the
            attribute contract against the Gradio version in use).

    Returns:
        Tuple of (results DataFrame, path to the saved XLSX, list of PIL
        thumbnails, and four PIL plot images: basic-color frequency,
        top-prediction distribution, weighted gender distribution, and
        age-by-gender histogram) — order must match the Gradio outputs list.
    """
    results = []
    thumbnails = []
    # Name XLSX after zip
    zip_basename = os.path.splitext(os.path.basename(zip_file.name))[0]
    with tempfile.TemporaryDirectory() as tmpdir:
        # NOTE(review): extractall() on an untrusted user ZIP is exposed to
        # path traversal ("zip slip") on older Python versions — consider
        # validating member names before extraction.
        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)
        # Only top-level files are scanned; images inside subdirectories of
        # the archive are ignored (os.listdir does not recurse).
        for fname in sorted(os.listdir(tmpdir)):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(tmpdir, fname)
                try:
                    image = Image.open(img_path).convert("RGB")
                except Exception:
                    # Unreadable/corrupt image: skip silently.
                    continue
                # Thumbnail for gallery (higher-quality)
                thumb = image.copy()
                thumb = thumb.resize((200, 200), Image.LANCZOS)
                thumbnails.append(thumb)
                # Classification
                input_tensor = transform(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(input_tensor)
                    probs = F.softmax(output, dim=1)[0]
                    top3_prob, top3_idx = torch.topk(probs, 3)
                preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx, prob in zip(top3_idx, top3_prob)]
                # Dominant color
                rgb, hex_color = get_dominant_color(image)
                basic_color = closest_basic_color(rgb)
                # Face detection & characterization
                faces_data = []
                try:
                    # OpenCV/DeepFace expect BGR channel order.
                    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                    detected_faces = DeepFace.analyze(
                        img_cv2, actions=["age", "gender", "emotion"], enforce_detection=False
                    )
                    # DeepFace may return a list of per-face dicts or a single
                    # dict depending on version; handle both shapes.
                    if isinstance(detected_faces, list):
                        for f in detected_faces:
                            faces_data.append({
                                "age": f["age"],
                                # "gender" is kept as DeepFace's per-class
                                # confidence dict; the plotting code below
                                # takes max() over it.
                                "gender": f["gender"],
                                "emotion": f["dominant_emotion"]
                            })
                    else:
                        faces_data.append({
                            "age": detected_faces["age"],
                            "gender": detected_faces["gender"],
                            "emotion": detected_faces["dominant_emotion"]
                        })
                except Exception:
                    # Best-effort: any face-analysis failure counts as "no faces".
                    faces_data = []
                results.append((
                    fname,
                    ", ".join([p[0] for p in preds]),
                    ", ".join([p[1] for p in preds]),
                    hex_color,
                    basic_color,
                    faces_data
                ))
    # Build dataframe
    df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
    # Save XLSX
    out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_basename}_results.xlsx")
    df.to_excel(out_xlsx, index=False)
    # ---------------------------
    # Plots
    # ---------------------------
    # 1. Basic color frequency
    fig1, ax1 = plt.subplots()
    color_counts = df["Basic Color"].value_counts()
    ax1.bar(color_counts.index, color_counts.values, color="skyblue")
    ax1.set_title("Basic Color Frequency")
    ax1.set_ylabel("Count")
    buf1 = io.BytesIO()
    plt.savefig(buf1, format="png")
    plt.close(fig1)
    buf1.seek(0)
    plot1_img = Image.open(buf1)
    # 2. Top prediction distribution
    fig2, ax2 = plt.subplots()
    preds_flat = []
    for p in df["Top 3 Predictions"]:
        preds_flat.extend(p.split(", "))
    pred_counts = pd.Series(preds_flat).value_counts().head(20)
    # Reversed so the most frequent prediction ends up at the top of the barh.
    ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
    ax2.set_title("Top Prediction Distribution")
    ax2.set_xlabel("Count")
    buf2 = io.BytesIO()
    plt.savefig(buf2, format="png", bbox_inches="tight")
    plt.close(fig2)
    buf2.seek(0)
    plot2_img = Image.open(buf2)
    # 3. Gender distribution (weighted)
    ages = []
    gender_confidence = {"Man": 0, "Woman": 0}
    for face_list in df["Face Info"]:
        for face in face_list:
            ages.append(face["age"])
            gender_dict = face["gender"]
            # Highest-confidence gender label for this face...
            gender = max(gender_dict, key=gender_dict.get)
            conf = float(gender_dict[gender]) / 100
            # ...with each face's contribution capped at 0.9 so one
            # over-confident detection cannot dominate the plot.
            weight = min(conf, 0.9)
            if gender in gender_confidence:
                gender_confidence[gender] += weight
            else:
                gender_confidence[gender] = weight
    fig3, ax3 = plt.subplots()
    ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
    ax3.set_title("Gender Distribution (Weighted ≤90%)")
    ax3.set_ylabel("Sum of Confidence")
    buf3 = io.BytesIO()
    plt.savefig(buf3, format="png")
    plt.close(fig3)
    buf3.seek(0)
    plot3_img = Image.open(buf3)
    # 4. Age distribution by gender
    ages_men = []
    ages_women = []
    for face_list in df["Face Info"]:
        for face in face_list:
            age = face["age"]
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            if gender.lower() == "man":
                ages_men.append(age)
            else:
                ages_women.append(age)
    fig4, ax4 = plt.subplots()
    # 5-year bins from 0 to 100.
    bins = range(0, 101, 5)
    ax4.hist([ages_men, ages_women], bins=bins, color=["lightblue", "pink"], label=["Men", "Women"], stacked=False)
    ax4.set_title("Age Distribution by Gender")
    ax4.set_xlabel("Age")
    ax4.set_ylabel("Count")
    ax4.legend()
    buf4 = io.BytesIO()
    plt.savefig(buf4, format="png")
    plt.close(fig4)
    buf4.seek(0)
    plot4_img = Image.open(buf4)
    return df, out_xlsx, thumbnails, plot1_img, plot2_img, plot3_img, plot4_img
# ---------------------------
# Gradio Interface
# ---------------------------
# Single-function UI: one ZIP upload in; a results table, an XLSX download,
# a thumbnail gallery, and four analysis plots out. The outputs list order
# must match the tuple returned by classify_zip_and_analyze_color.
demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    outputs=[
        gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"]),
        gr.File(label="Download XLSX"),
        gr.Gallery(label="Thumbnails", show_label=True, elem_id="thumbnail-gallery", columns=5),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
        gr.Image(type="pil", label="Age Distribution by Gender"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description="Upload a ZIP of images. Classifies images, analyzes dominant color, detects faces, and displays thumbnails.",
)
if __name__ == "__main__":
    # Binds 0.0.0.0:7860 — presumably for containerized hosting (e.g. HF
    # Spaces expects port 7860); confirm before changing.
    demo.launch(server_name="0.0.0.0", server_port=7860)