Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -64,7 +64,75 @@ def classify_cat(image):
|
|
| 64 |
# CLIP
|
| 65 |
clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
|
| 66 |
clip_results = clip_classifier(image, candidate_labels=clip_labels)
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
clip_output = {}
|
| 69 |
for r in clip_results:
|
| 70 |
label = r["label"].replace("a photo of a ", "").lower()
|
|
|
|
| 64 |
# CLIP
|
| 65 |
clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
|
| 66 |
clip_results = clip_classifier(image, candidate_labels=clip_labels)
|
| 67 |
+
def classify_with_openai(image_path):
|
| 68 |
+
base64_image = encode_image(image_path)
|
| 69 |
+
|
| 70 |
+
prompt = f"""
|
| 71 |
+
You are a big cat classifier.
|
| 72 |
+
|
| 73 |
+
Classify the image into exactly one of these labels:
|
| 74 |
+
|
| 75 |
+
{CAT_LABELS}
|
| 76 |
+
|
| 77 |
+
Return ONLY valid JSON.
|
| 78 |
+
Do not use markdown.
|
| 79 |
+
Do not use code fences.
|
| 80 |
+
Do not add explanations.
|
| 81 |
+
|
| 82 |
+
Required format:
|
| 83 |
+
{{"label":"one_of_{CAT_LABELS}","confidence":0.0}}
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
response = client.responses.create(
|
| 88 |
+
model="gpt-4.1-mini",
|
| 89 |
+
input=[
|
| 90 |
+
{
|
| 91 |
+
"role": "user",
|
| 92 |
+
"content": [
|
| 93 |
+
{"type": "input_text", "text": prompt},
|
| 94 |
+
{
|
| 95 |
+
"type": "input_image",
|
| 96 |
+
"image_url": f"data:image/jpeg;base64,{base64_image}"
|
| 97 |
+
}
|
| 98 |
+
]
|
| 99 |
+
}
|
| 100 |
+
]
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
text = response.output_text.strip()
|
| 104 |
+
text = text.replace("```json", "").replace("```", "").strip()
|
| 105 |
+
|
| 106 |
+
start = text.find("{")
|
| 107 |
+
end = text.rfind("}")
|
| 108 |
+
if start != -1 and end != -1 and end > start:
|
| 109 |
+
text = text[start:end+1]
|
| 110 |
+
|
| 111 |
+
result = json.loads(text)
|
| 112 |
+
|
| 113 |
+
label = str(result["label"]).strip().lower()
|
| 114 |
+
confidence = float(result["confidence"])
|
| 115 |
+
|
| 116 |
+
if label not in CAT_LABELS:
|
| 117 |
+
raise ValueError(f"Invalid label: {label}")
|
| 118 |
+
|
| 119 |
+
confidence = max(0.0, min(1.0, confidence))
|
| 120 |
+
remaining = 1.0 - confidence
|
| 121 |
+
num_other = len(CAT_LABELS) - 1
|
| 122 |
+
|
| 123 |
+
distribution = {}
|
| 124 |
+
|
| 125 |
+
for l in CAT_LABELS:
|
| 126 |
+
if l == label:
|
| 127 |
+
distribution[l] = confidence
|
| 128 |
+
else:
|
| 129 |
+
distribution[l] = remaining / num_other
|
| 130 |
+
|
| 131 |
+
return distribution
|
| 132 |
+
|
| 133 |
+
except Exception:
|
| 134 |
+
return {"unknown": 1.0}
|
| 135 |
+
|
| 136 |
clip_output = {}
|
| 137 |
for r in clip_results:
|
| 138 |
label = r["label"].replace("a photo of a ", "").lower()
|