# app.py — Hugging Face Space by clementBE (commit 301006f)
import os
import zipfile
import tempfile
import requests
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.cluster import MiniBatchKMeans
import matplotlib.pyplot as plt
import io
import gradio as gr
# Face analysis
from deepface import DeepFace
import cv2
# ---------------------------
# Force CPU if no CUDA
# ---------------------------
if not torch.cuda.is_available():
    # Hide CUDA devices entirely so other libraries loaded later
    # (presumably the DeepFace backend as well) also stay on CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---------------------------
# Load ResNet50
# ---------------------------
# Pretrained ImageNet weights; downloaded on first run.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# ---------------------------
# Transformations
# ---------------------------
# Standard ImageNet preprocessing for ResNet50: resize the short side to
# 256, center-crop to 224x224, convert to a tensor, and normalize with
# the canonical ImageNet channel means / standard deviations.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# ---------------------------
# ImageNet labels
# ---------------------------
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Fetch the 1000 ImageNet class names once at startup.  The timeout keeps
# app startup from hanging forever on a stalled connection, and the status
# check prevents silently building a label list out of an HTML error page.
_labels_response = requests.get(LABELS_URL, timeout=10)
_labels_response.raise_for_status()
imagenet_classes = [line.strip() for line in _labels_response.text.splitlines()]
# ---------------------------
# Basic color utilities
# ---------------------------
BASIC_COLORS = {
    "Red": (255, 0, 0),
    "Green": (0, 255, 0),
    "Blue": (0, 0, 255),
    "Yellow": (255, 255, 0),
    "Cyan": (0, 255, 255),
    "Magenta": (255, 0, 255),
    "Black": (0, 0, 0),
    "White": (255, 255, 255),
    "Gray": (128, 128, 128),
}


def closest_basic_color(rgb):
    """Return the name of the BASIC_COLORS entry nearest to ``rgb``.

    Nearest is measured by squared Euclidean distance in RGB space; ties
    go to the color listed first in BASIC_COLORS.
    """
    r, g, b = rgb

    def _sq_dist(item):
        _name, (cr, cg, cb) = item
        return (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2

    name, _rgb = min(BASIC_COLORS.items(), key=_sq_dist)
    return name
def get_dominant_color(image, num_colors=5):
    """Return the dominant color of an image as ((r, g, b), "#rrggbb").

    The image is downsampled to 300x300 and its pixels clustered with
    MiniBatchKMeans; the centroid of the most-populated cluster is taken
    as the dominant color.

    Parameters
    ----------
    image : PIL.Image
        Input image.  Converted to RGB first so RGBA or grayscale inputs
        also work (they would otherwise break the (-1, 3) reshape).
    num_colors : int, optional
        Number of k-means clusters to fit (default 5).

    Returns
    -------
    ((int, int, int), str)
        The dominant RGB triple (plain Python ints) and its hex string.
    """
    small = image.convert("RGB").resize((300, 300))
    pixels = np.asarray(small).reshape(-1, 3)
    kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
    kmeans.fit(pixels)
    # Centroid of the largest cluster (bincount over cluster labels).
    dominant = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
    # Clamp to the valid 0-255 byte range and use plain Python ints so the
    # hex formatting below cannot produce negative or >2-digit components.
    r, g, b = (int(min(max(c, 0), 255)) for c in dominant)
    hex_color = f"#{r:02x}{g:02x}{b:02x}"
    return (r, g, b), hex_color
# ---------------------------
# Core function
# ---------------------------
# File extensions treated as images when scanning the extracted archive.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _iter_image_paths(root):
    """Return sorted paths of all image files under *root*, recursively.

    Unlike a flat ``os.listdir``, this also finds images inside ZIP
    subfolders.  Hidden files (including macOS "._*" AppleDouble entries
    that commonly pollute user-made archives) are skipped.
    """
    paths = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            if fname.startswith("."):
                continue  # hidden / resource-fork junk, not a real image
            if fname.lower().endswith(IMAGE_EXTENSIONS):
                paths.append(os.path.join(dirpath, fname))
    return sorted(paths)


def _classify_top3(image):
    """Return the top-3 ImageNet predictions for a PIL image as a list of
    (label, "xx.xx%") tuples, using the module-level ResNet50."""
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1)[0]
    top3_prob, top3_idx = torch.topk(probs, 3)
    return [(imagenet_classes[idx], f"{prob.item()*100:.2f}%")
            for idx, prob in zip(top3_idx, top3_prob)]


def _analyze_faces(image):
    """Run DeepFace age/gender/emotion analysis on a PIL image.

    Returns a list of ``{"age", "gender", "emotion"}`` dicts, one per
    detected face, or [] when analysis fails.  NOTE: "gender" is kept as
    DeepFace's per-class confidence dict — the plotting helpers below
    rely on that shape.
    """
    try:
        img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        detected = DeepFace.analyze(
            img_bgr, actions=["age", "gender", "emotion"], enforce_detection=False
        )
        if not isinstance(detected, list):
            detected = [detected]
        return [
            {"age": f["age"], "gender": f["gender"], "emotion": f["dominant_emotion"]}
            for f in detected
        ]
    except Exception:
        # Face analysis is best-effort; never abort the whole batch.
        return []


def _fig_to_image(fig):
    """Render a matplotlib figure to a PIL image and release the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _plot_color_frequency(df):
    """Bar chart: how often each basic color is the dominant one."""
    fig, ax = plt.subplots()
    counts = df["Basic Color"].value_counts()
    ax.bar(counts.index, counts.values, color="skyblue")
    ax.set_title("Basic Color Frequency")
    ax.set_ylabel("Count")
    return _fig_to_image(fig)


def _plot_prediction_distribution(df):
    """Horizontal bar chart of the 20 most frequent predicted labels."""
    labels = []
    for cell in df["Top 3 Predictions"]:
        labels.extend(cell.split(", "))
    counts = pd.Series(labels).value_counts().head(20)
    fig, ax = plt.subplots()
    ax.barh(counts.index[::-1], counts.values[::-1], color="salmon")
    ax.set_title("Top Prediction Distribution")
    ax.set_xlabel("Count")
    return _fig_to_image(fig)


def _plot_gender_distribution(df):
    """Bar chart of gender frequency weighted by (capped) confidence."""
    gender_confidence = {"Man": 0.0, "Woman": 0.0}
    for face_list in df["Face Info"]:
        for face in face_list:
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            # Cap each face's contribution at 0.9 so one over-confident
            # detection cannot dominate the chart.
            weight = min(float(gender_dict[gender]) / 100, 0.9)
            gender_confidence[gender] = gender_confidence.get(gender, 0.0) + weight
    fig, ax = plt.subplots()
    ax.bar(gender_confidence.keys(), gender_confidence.values(),
           color=["lightblue", "pink"])
    ax.set_title("Gender Distribution (Weighted ≤90%)")
    ax.set_ylabel("Sum of Confidence")
    return _fig_to_image(fig)


def _plot_age_distribution(df):
    """Side-by-side age histograms for men and women, 5-year bins."""
    ages_men, ages_women = [], []
    for face_list in df["Face Info"]:
        for face in face_list:
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            (ages_men if gender.lower() == "man" else ages_women).append(face["age"])
    fig, ax = plt.subplots()
    ax.hist([ages_men, ages_women], bins=range(0, 101, 5),
            color=["lightblue", "pink"], label=["Men", "Women"], stacked=False)
    ax.set_title("Age Distribution by Gender")
    ax.set_xlabel("Age")
    ax.set_ylabel("Count")
    ax.legend()
    return _fig_to_image(fig)


def classify_zip_and_analyze_color(zip_file):
    """Classify every image in an uploaded ZIP and analyze colors/faces.

    Parameters
    ----------
    zip_file : file-like object with a ``.name`` path attribute (as
        provided by ``gr.File``) pointing at a ZIP archive of images.

    Returns
    -------
    tuple
        (results DataFrame, path to the XLSX export, list of thumbnail
        PIL images, and four PIL plot images: color frequency,
        prediction distribution, gender distribution, age distribution).
    """
    results = []
    thumbnails = []
    # Name the XLSX export after the uploaded archive.
    zip_basename = os.path.splitext(os.path.basename(zip_file.name))[0]
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)
        for img_path in _iter_image_paths(tmpdir):
            try:
                image = Image.open(img_path).convert("RGB")
            except Exception:
                continue  # unreadable / corrupt file: skip it
            # Thumbnail for the gallery (higher-quality resampling).
            thumbnails.append(image.copy().resize((200, 200), Image.LANCZOS))
            preds = _classify_top3(image)
            rgb, hex_color = get_dominant_color(image)
            results.append((
                os.path.basename(img_path),  # NOTE: basenames may repeat across subfolders
                ", ".join(label for label, _ in preds),
                ", ".join(conf for _, conf in preds),
                hex_color,
                closest_basic_color(rgb),
                _analyze_faces(image),
            ))
    # Build dataframe and save the XLSX next to the system temp dir.
    df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
    out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_basename}_results.xlsx")
    df.to_excel(out_xlsx, index=False)
    return (
        df,
        out_xlsx,
        thumbnails,
        _plot_color_frequency(df),
        _plot_prediction_distribution(df),
        _plot_gender_distribution(df),
        _plot_age_distribution(df),
    )
# ---------------------------
# Gradio Interface
# ---------------------------
demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    # Output components must match the function's return tuple, in order.
    outputs=[
        gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"]),
        gr.File(label="Download XLSX"),
        gr.Gallery(label="Thumbnails", show_label=True, elem_id="thumbnail-gallery", columns=5),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
        gr.Image(type="pil", label="Age Distribution by Gender"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description="Upload a ZIP of images. Classifies images, analyzes dominant color, detects faces, and displays thumbnails.",
)
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the conventional port for
    # Hugging Face Spaces / containerized deployments).
    demo.launch(server_name="0.0.0.0", server_port=7860)