Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,20 +34,27 @@ def classify_image(image):
|
|
| 34 |
# Run inference
|
| 35 |
outputs = model(**inputs)
|
| 36 |
|
| 37 |
-
# Extract logits
|
| 38 |
-
logits_per_image = outputs.logits_per_image #
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# Extract probabilities for each category
|
| 42 |
-
safe_prob = probs[0][0] #
|
| 43 |
-
unsafe_prob = probs[0][1] #
|
|
|
|
| 44 |
|
| 45 |
# Normalize probabilities to ensure they sum to 100%
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
-
#
|
|
|
|
| 51 |
return {
|
| 52 |
"safe": round(safe_percentage, 2), # Rounded to 2 decimal places
|
| 53 |
"unsafe": round(unsafe_percentage, 2)
|
|
@@ -58,6 +65,7 @@ def classify_image(image):
|
|
| 58 |
|
| 59 |
|
| 60 |
|
|
|
|
| 61 |
# Step 3: Set Up Gradio Interface
|
| 62 |
iface = gr.Interface(
|
| 63 |
fn=classify_image,
|
|
|
|
| 34 |
# Run inference
|
| 35 |
outputs = model(**inputs)
|
| 36 |
|
| 37 |
+
# Extract logits
|
| 38 |
+
logits_per_image = outputs.logits_per_image # Shape: [1, 2]
|
| 39 |
+
print(f"Logits: {logits_per_image}")
|
| 40 |
+
|
| 41 |
+
# Apply softmax to logits to get probabilities
|
| 42 |
+
probs = logits_per_image.softmax(dim=1) # Shape: [1, 2]
|
| 43 |
+
print(f"Softmax probabilities: {probs}")
|
| 44 |
|
| 45 |
# Extract probabilities for each category
|
| 46 |
+
safe_prob = probs[0][0].item() # Extract 'safe' probability
|
| 47 |
+
unsafe_prob = probs[0][1].item() # Extract 'unsafe' probability
|
| 48 |
+
print(f"Safe probability: {safe_prob}, Unsafe probability: {unsafe_prob}")
|
| 49 |
|
| 50 |
# Normalize probabilities to ensure they sum to 100%
|
| 51 |
+
total_prob = safe_prob + unsafe_prob
|
| 52 |
+
print(f"Total probability before normalization: {total_prob}")
|
| 53 |
+
safe_percentage = (safe_prob / total_prob) * 100
|
| 54 |
+
unsafe_percentage = (unsafe_prob / total_prob) * 100
|
| 55 |
|
| 56 |
+
# Ensure the sum is exactly 100%
|
| 57 |
+
print(f"Normalized percentages: Safe={safe_percentage}%, Unsafe={unsafe_percentage}%")
|
| 58 |
return {
|
| 59 |
"safe": round(safe_percentage, 2), # Rounded to 2 decimal places
|
| 60 |
"unsafe": round(unsafe_percentage, 2)
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
|
| 68 |
+
|
| 69 |
# Step 3: Set Up Gradio Interface
|
| 70 |
iface = gr.Interface(
|
| 71 |
fn=classify_image,
|