Ni3SinghR committed on
Commit
d4e1911
·
verified ·
1 Parent(s): 08d21ad

Upload 4 files

Browse files
Files changed (4) hide show
  1. ViT-B-32.pt +3 -0
  2. app.py +72 -0
  3. main.py +121 -0
  4. requirements.txt +93 -0
ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
3
+ size 353976522
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import requests
4
+ from PIL import Image
5
+ import io
6
+
7
+ # --- Configuration ---
8
+ BACKEND_URL = "http://127.0.0.1:8000/predict"
9
+
10
+ # --- Interface Logic ---
11
def predict_gender(image):
    """
    Send an uploaded image to the FastAPI backend and return its prediction.

    Args:
        image: NumPy array from the Gradio Image component (RGB), or None
            when the user has not uploaded anything.

    Returns:
        dict: Label -> confidence mapping from the backend (rendered by gr.Label).

    Raises:
        gr.Error: On missing input, an API error response, connection failure,
            or any unexpected processing error.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    try:
        # Convert the numpy array to an in-memory PNG for the multipart upload.
        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
        img_byte_arr = io.BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)  # Move cursor to the beginning of the buffer

        # Prepare the file for the POST request
        files = {'file': ('image.png', img_byte_arr, 'image/png')}

        # Send request to the backend
        response = requests.post(BACKEND_URL, files=files, timeout=30)

        # Process the response
        if response.status_code == 200:
            return response.json()

        # Surface the backend's error detail. The error body may not be JSON
        # (e.g. an HTML page from a proxy), so guard the decode.
        try:
            error_detail = response.json().get('detail', 'An unknown error occurred.')
        except ValueError:
            error_detail = f"Backend returned status {response.status_code}."
        raise gr.Error(f"API Error: {error_detail}")

    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Could not connect to the backend. Please ensure the backend is running. Details: {e}")
    except gr.Error:
        # Bug fix: previously our own gr.Error raised above was caught by the
        # generic handler below and re-wrapped as "An unexpected error
        # occurred: ...". Re-raise it unchanged instead.
        raise
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {e}")
44
+
45
+
46
# --- Gradio Interface Definition ---
# Hoisted so the Interface call below stays compact.
_DESCRIPTION = (
    "Upload a clear, front-facing photo of a single person to predict their gender. "
    "The app uses a backend API powered by OpenAI's CLIP model."
)

_EXAMPLES = [
    ["examples/male_example.jpg"],
    ["examples/female_example.jpg"],
]

iface = gr.Interface(
    fn=predict_gender,
    inputs=gr.Image(label="Upload a Photo", type="numpy"),
    outputs=gr.Label(label="Gender Prediction", num_top_classes=2),
    title="📸 Gender Prediction with CLIP",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
    allow_flagging="never",
    css=".gradio-container {max-width: 780px !important; margin: auto;}",
)
63
+
64
# --- Launch the App ---
if __name__ == "__main__":
    # Ensure an 'examples' directory exists for the Gradio examples.
    import os
    if not os.path.isdir("examples"):
        # exist_ok=True avoids the check-then-create race of the previous
        # os.path.exists() + os.makedirs() sequence.
        os.makedirs("examples", exist_ok=True)
        print("Created 'examples' directory. Please add 'male_example.jpg' and 'female_example.jpg' for the demo.")

    iface.launch()
main.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import uvicorn
3
+ import numpy as np
4
+ import clip
5
+ import torch
6
+ from fastapi import FastAPI, File, UploadFile, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from retinaface import RetinaFace
9
+ from PIL import Image
10
+ import io
11
+ import os
12
+
13
# --- Constants & Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODELS_DIR = "models"
# Prompts fed to CLIP; the label shown to the user is the word after "of a ".
GENDER_PROMPTS = ["a photo of a man", "a photo of a woman"]

# --- Error Messages ---
ERROR_MESSAGES = {
    "NO_FACE": "No face detected. Please upload a clear, front-facing picture of a single person.",
    "MULTIPLE_FACES": "Multiple faces detected. Please upload an image with only one face.",
    "ANALYSIS_ERROR": "An unexpected error occurred during analysis. Please try again.",
    "FILE_READ_ERROR": "Could not read the uploaded file. Please ensure it's a valid image."
}

# --- Model Loading ---
# Create models directory if it doesn't exist
os.makedirs(MODELS_DIR, exist_ok=True)

try:
    print(f"Loading CLIP model on device: {DEVICE}...")
    # Load the model, downloading to the specified directory if necessary
    model, preprocess = clip.load("ViT-B/32", device=DEVICE, download_root=MODELS_DIR)
    print("✓ CLIP model loaded successfully.")
except Exception as e:
    print(f"✗ Failed to load CLIP model: {e}")
    # Bug fix: bare exit() is the interactive-shell helper and exits with
    # status 0, so supervisors would see a "successful" start despite the
    # model failing to load. SystemExit(1) reports failure properly.
    raise SystemExit(1)
38
+
39
# --- FastAPI App Initialization ---
app = FastAPI(
    title="Gender Detection API",
    description="A simple API using CLIP to predict gender from an image."
)

# Permissive CORS so any frontend (e.g. the Gradio UI) can reach this API.
_cors_config = dict(
    allow_origins=["*"],  # Allows all origins for simplicity
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_config)
52
+
53
# --- Core Logic ---
def predict_gender_with_clip(image: Image.Image) -> dict:
    """
    Predicts gender from a PIL Image using the loaded CLIP model.

    Args:
        image (Image.Image): The input image.

    Returns:
        dict: A dictionary with gender labels and their confidence scores.
    """
    pixels = preprocess(image).unsqueeze(0).to(DEVICE)
    prompts = clip.tokenize(GENDER_PROMPTS).to(DEVICE)

    with torch.no_grad():
        image_logits, _ = model(pixels, prompts)
        # Softmax turns the raw logits into probabilities over the prompts.
        probs = image_logits.softmax(dim=-1).cpu().numpy()[0]

    # Map each probability back to a short label:
    # "a photo of a man" -> "man", "a photo of a woman" -> "woman".
    scores = {}
    for prompt, prob in zip(GENDER_PROMPTS, probs):
        scores[prompt.split("of a ")[-1]] = float(prob)
    return scores
74
+
75
+
76
# --- API Endpoints ---
@app.get("/health")
async def health_check():
    """Liveness probe: confirms the API process is up and responding."""
    payload = {"status": "healthy"}
    return payload
81
+
82
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Main prediction endpoint. Validates the uploaded image, requires exactly
    one detected face, and returns gender probabilities.

    Raises:
        HTTPException: 400 for an unreadable upload, 422 for zero or multiple
            detected faces, 500 for any unexpected analysis failure.
    """
    try:
        # 1. Read and validate the uploaded image
        contents = await file.read()
        image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
        # Convert to numpy array for face detection (expects BGR)
        image_np = np.array(image_pil)
        image_np = image_np[:, :, ::-1].copy()  # RGB -> BGR
    except Exception:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES["FILE_READ_ERROR"])

    try:
        # 2. Detect faces using RetinaFace
        faces = RetinaFace.detect_faces(image_np)
        # NOTE(review): RetinaFace returns a dict keyed by face id on success,
        # but a non-dict (tuple/empty) result when nothing is detected, in
        # which case a bare len() would miscount or raise. Treat any non-dict
        # as "no faces" — TODO confirm against the pinned retina-face version.
        num_faces = len(faces) if isinstance(faces, dict) else 0

        if num_faces == 0:
            raise HTTPException(status_code=422, detail=ERROR_MESSAGES["NO_FACE"])
        if num_faces > 1:
            raise HTTPException(status_code=422, detail=ERROR_MESSAGES["MULTIPLE_FACES"])

        # 3. Predict gender using CLIP
        return predict_gender_with_clip(image_pil)

    except HTTPException:
        # Re-raise known HTTP exceptions; bare raise preserves the traceback
        # (the previous `raise e` rebuilt it from here).
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise HTTPException(status_code=500, detail=ERROR_MESSAGES["ANALYSIS_ERROR"])
118
+
119
# --- Main Execution ---
if __name__ == "__main__":
    # Bind to localhost only; the Gradio frontend targets this exact
    # host/port (BACKEND_URL = "http://127.0.0.1:8000/predict" in app.py).
    uvicorn.run(app, host="127.0.0.1", port=8000)
requirements.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ astunparse==1.6.3
5
+ attrs==25.3.0
6
+ beautifulsoup4==4.13.4
7
+ cachetools==5.5.2
8
+ certifi==2025.6.15
9
+ cffi==1.17.1
10
+ charset-normalizer==3.4.2
11
+ click==8.2.1
12
+ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
13
+ colorama==0.4.6
14
+ contourpy==1.3.2
15
+ cycler==0.12.1
16
+ exceptiongroup==1.3.0
17
+ fastapi==0.115.13
18
+ filelock==3.18.0
19
+ flatbuffers==25.2.10
20
+ fonttools==4.58.4
21
+ fsspec==2025.5.1
22
+ ftfy==6.3.1
23
+ gast==0.6.0
24
+ gdown==5.2.0
25
+ google-auth==2.40.3
26
+ google-auth-oauthlib==1.0.0
27
+ google-pasta==0.2.0
28
+ grpcio==1.73.1
29
+ h11==0.16.0
30
+ h5py==3.14.0
31
+ idna==3.10
32
+ jax==0.4.34
33
+ jaxlib==0.4.34
34
+ jinja2==3.1.6
35
+ keras==2.14.0
36
+ kiwisolver==1.4.8
37
+ libclang==18.1.1
38
+ markdown==3.8.2
39
+ markupsafe==3.0.2
40
+ matplotlib==3.10.3
41
+ mediapipe==0.10.14
42
+ ml-dtypes==0.2.0
43
+ mpmath==1.3.0
44
+ networkx==3.4.2
45
+ numpy==1.24.4
46
+ oauthlib==3.3.1
47
+ opencv-contrib-python==4.11.0.86
48
+ opencv-python==4.11.0.86
49
+ opencv-python-headless==4.11.0.86
50
+ opt-einsum==3.4.0
51
+ packaging==25.0
52
+ pillow==11.2.1
53
+ protobuf==4.25.8
54
+ pyasn1==0.6.1
55
+ pyasn1-modules==0.4.2
56
+ pycparser==2.22
57
+ pydantic==2.11.7
58
+ pydantic-core==2.33.2
59
+ pyparsing==3.2.3
60
+ pysocks==1.7.1
61
+ python-dateutil==2.9.0.post0
62
+ regex==2024.11.6
63
+ requests==2.32.4
64
+ requests-oauthlib==2.0.0
65
+ retina-face==0.0.17
66
+ rsa==4.9.1
67
+ scipy==1.15.3
68
+ setuptools==80.9.0
69
+ six==1.17.0
70
+ sniffio==1.3.1
71
+ sounddevice==0.5.2
72
+ soupsieve==2.7
73
+ starlette==0.46.2
74
+ sympy==1.14.0
75
+ tensorboard==2.14.1
76
+ tensorboard-data-server==0.7.2
77
+ tensorflow==2.14.0
78
+ tensorflow-estimator==2.14.0
79
+ tensorflow-intel==2.14.0
80
+ tensorflow-io-gcs-filesystem==0.31.0
81
+ termcolor==3.1.0
82
+ torch==2.7.1
83
+ torchaudio==2.7.1
84
+ torchvision==0.22.1
85
+ tqdm==4.67.1
86
+ typing-extensions==4.14.0
87
+ typing-inspection==0.4.1
88
+ urllib3==2.5.0
89
+ uvicorn==0.34.3
90
+ wcwidth==0.2.13
91
+ werkzeug==3.1.3
92
+ wheel==0.45.1
93
+ wrapt==1.14.1