DKatheesrupan commited on
Commit
ce2bb40
·
verified ·
1 Parent(s): e431a80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -1
app.py CHANGED
@@ -64,7 +64,75 @@ def classify_cat(image):
64
  # CLIP
65
  clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
66
  clip_results = clip_classifier(image, candidate_labels=clip_labels)
67
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  clip_output = {}
69
  for r in clip_results:
70
  label = r["label"].replace("a photo of a ", "").lower()
 
64
  # CLIP
65
  clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
66
  clip_results = clip_classifier(image, candidate_labels=clip_labels)
67
+ def classify_with_openai(image_path):
68
+ base64_image = encode_image(image_path)
69
+
70
+ prompt = f"""
71
+ You are a big cat classifier.
72
+
73
+ Classify the image into exactly one of these labels:
74
+
75
+ {CAT_LABELS}
76
+
77
+ Return ONLY valid JSON.
78
+ Do not use markdown.
79
+ Do not use code fences.
80
+ Do not add explanations.
81
+
82
+ Required format:
83
+ {{"label":"one_of_{CAT_LABELS}","confidence":0.0}}
84
+ """
85
+
86
+ try:
87
+ response = client.responses.create(
88
+ model="gpt-4.1-mini",
89
+ input=[
90
+ {
91
+ "role": "user",
92
+ "content": [
93
+ {"type": "input_text", "text": prompt},
94
+ {
95
+ "type": "input_image",
96
+ "image_url": f"data:image/jpeg;base64,{base64_image}"
97
+ }
98
+ ]
99
+ }
100
+ ]
101
+ )
102
+
103
+ text = response.output_text.strip()
104
+ text = text.replace("```json", "").replace("```", "").strip()
105
+
106
+ start = text.find("{")
107
+ end = text.rfind("}")
108
+ if start != -1 and end != -1 and end > start:
109
+ text = text[start:end+1]
110
+
111
+ result = json.loads(text)
112
+
113
+ label = str(result["label"]).strip().lower()
114
+ confidence = float(result["confidence"])
115
+
116
+ if label not in CAT_LABELS:
117
+ raise ValueError(f"Invalid label: {label}")
118
+
119
+ confidence = max(0.0, min(1.0, confidence))
120
+ remaining = 1.0 - confidence
121
+ num_other = len(CAT_LABELS) - 1
122
+
123
+ distribution = {}
124
+
125
+ for l in CAT_LABELS:
126
+ if l == label:
127
+ distribution[l] = confidence
128
+ else:
129
+ distribution[l] = remaining / num_other
130
+
131
+ return distribution
132
+
133
+ except Exception:
134
+ return {"unknown": 1.0}
135
+
136
  clip_output = {}
137
  for r in clip_results:
138
  label = r["label"].replace("a photo of a ", "").lower()