Alief Gilang Permana Putra commited on
Commit
78bf372
·
1 Parent(s): b6a268e

feat: Add gradio app

Browse files
Files changed (2) hide show
  1. app.py +282 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import base64
4
+ import json
5
+ import tempfile
6
+ import os
7
+ from io import BytesIO
8
+ from PIL import Image
9
+
10
+ INFERENCE_API_URL = os.getenv("INFERENCE_API_URL", "http://127.0.0.1:8000")
11
+ INTERPRETATION_API_URL = os.getenv("INTERPRETATION_API_URL", "http://127.0.0.1:8080")
12
+
13
+
14
+ def get_available_models():
15
+ """Fetch available models from the FastAPI server."""
16
+ try:
17
+ response = requests.get(f"{INFERENCE_API_URL}/models", timeout=2)
18
+ if response.status_code == 200:
19
+ models_data = response.json().get("available_models", [])
20
+ # Return list of tuples: (Display Name, model_id) for the dropdown
21
+ return [(f"{m.get('name', m.get('id'))}", m.get("id")) for m in models_data]
22
+ except Exception as e:
23
+ print(f"Warning: Could not fetch models from API ({e}). Using defaults.")
24
+ # Fallback default models if API is unreachable during startup
25
+ return [("SwinV2 (swinv2)", "swinv2"), ("ViT (vit)", "vit"), ("PVTv2 (pvtv2)", "pvtv2")]
26
+
27
+ def predict(image, model_type):
28
+ if image is None:
29
+ return {"error": "Please upload an image."}, None
30
+ if not model_type:
31
+ return {"error": "Please select a model."}, None
32
+
33
+ # Convert PIL Image to Base64 string
34
+ buffered = BytesIO()
35
+ image.save(buffered, format="JPEG")
36
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
37
+
38
+ payload = {
39
+ "model_type": model_type,
40
+ "image_base64": img_str
41
+ }
42
+
43
+ try:
44
+ response = requests.post(f"{INFERENCE_API_URL}/predict", json=payload, timeout=30)
45
+ if response.status_code == 200:
46
+ data = response.json()
47
+ predictions = data.get("predictions", {})
48
+ cropped_b64 = data.get("cropped_face_base64")
49
+
50
+ cropped_img = None
51
+ if cropped_b64:
52
+ try:
53
+ img_data = base64.b64decode(cropped_b64)
54
+ cropped_img = Image.open(BytesIO(img_data)).convert("RGB")
55
+ except Exception:
56
+ pass
57
+
58
+ return predictions, cropped_img
59
+ else:
60
+ return {"error": f"HTTP {response.status_code}", "details": response.text}, None
61
+ except Exception as e:
62
+ return {"error": "Connection failed. Is the API running?", "details": str(e)}, None
63
+
64
+ # --- Interpretation API helpers ---
65
+
66
+ def get_inference_models():
67
+ """Fetch inference models from the interpretation API."""
68
+ try:
69
+ response = requests.get(f"{INTERPRETATION_API_URL}/inference-models", timeout=2)
70
+ if response.status_code == 200:
71
+ data = response.json()
72
+ if isinstance(data, dict):
73
+ return data.get("available_models", [])
74
+ return data
75
+ except Exception as e:
76
+ print(f"Warning: Could not fetch inference models ({e}).")
77
+ return ["swinv2", "vit", "pvtv2"]
78
+
79
+ def get_llm_models():
80
+ """Fetch allowed LLM models from the interpretation API."""
81
+ try:
82
+ response = requests.get(f"{INTERPRETATION_API_URL}/llm-models", timeout=2)
83
+ if response.status_code == 200:
84
+ models = response.json()
85
+ return [(m["name"], m["id"]) for m in models]
86
+ except Exception as e:
87
+ print(f"Warning: Could not fetch LLM models ({e}).")
88
+ return [("Gemma 4 31B (free)", "google/gemma-4-31b-it:free")]
89
+
90
+ def get_response_styles():
91
+ """Fetch allowed response styles from the interpretation API."""
92
+ try:
93
+ response = requests.get(f"{INTERPRETATION_API_URL}/response-styles", timeout=2)
94
+ if response.status_code == 200:
95
+ styles = response.json()
96
+ return [(s["name"], s["id"]) for s in styles]
97
+ except Exception as e:
98
+ print(f"Warning: Could not fetch response styles ({e}).")
99
+ return [("Comprehensive (ID)", "comprehensive_id")]
100
+
101
+ def interpret(image, inference_model, llm_model, style_id):
102
+ """Send image to the interpretation API via multipart/form-data."""
103
+ if image is None:
104
+ return {}, "Please upload an image."
105
+ if not inference_model:
106
+ return {}, "Please select an inference model."
107
+ if not llm_model:
108
+ return {}, "Please select an LLM model."
109
+
110
+ # Convert PIL image to bytes for multipart upload
111
+ buffered = BytesIO()
112
+ image.save(buffered, format="JPEG")
113
+ buffered.seek(0)
114
+
115
+ try:
116
+ files = {"image": ("image.jpg", buffered, "image/jpeg")}
117
+ data = {
118
+ "inference_model": inference_model,
119
+ "llm_model": llm_model,
120
+ "style_id": style_id,
121
+ }
122
+ response = requests.post(
123
+ f"{INTERPRETATION_API_URL}/interpret",
124
+ files=files,
125
+ data=data,
126
+ timeout=120,
127
+ )
128
+ if response.status_code == 200:
129
+ result = response.json()
130
+ traits = result.get("predictions", {})
131
+ interpretation = result.get("interpretation", "No interpretation returned.")
132
+ return traits, interpretation
133
+ else:
134
+ err = response.json().get("error", response.text)
135
+ return {}, f"Error {response.status_code}: {err}"
136
+ except Exception as e:
137
+ return {}, f"Connection failed. Is the interpretation API running?\n{e}"
138
+
139
+ def export_result(image, inf_model, llm_id, style_id, traits, interpretation):
140
+ """Exports the results to a JSON file and returns the temp file path."""
141
+ if not traits and not interpretation:
142
+ return None # Nothing to export
143
+
144
+ img_b64 = None
145
+ if image is not None:
146
+ buffered = BytesIO()
147
+ image.save(buffered, format="JPEG")
148
+ img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
149
+
150
+ data = {
151
+ "parameters": {
152
+ "inference_model": inf_model,
153
+ "llm_model": llm_id,
154
+ "response_style": style_id
155
+ },
156
+ "results": {
157
+ "predictions": traits,
158
+ "interpretation": interpretation
159
+ },
160
+ "image_base64": img_b64
161
+ }
162
+
163
+ fd, path = tempfile.mkstemp(suffix=".json", prefix="personality_export_")
164
+ with os.fdopen(fd, 'w', encoding='utf-8') as f:
165
+ json.dump(data, f, indent=4)
166
+
167
+ return path
168
+
169
+
170
+ # --- Build combined app ---
171
+
172
+ def build_app():
173
+ models = get_available_models()
174
+ inf_models_raw = get_inference_models()
175
+
176
+ # Map inference model IDs to display names (Name, ID)
177
+ id_to_name = {m_id: m_name for m_name, m_id in models}
178
+
179
+ inf_models = []
180
+ for m in inf_models_raw:
181
+ if isinstance(m, dict):
182
+ inf_models.append((m.get("name", m.get("id")), m.get("id")))
183
+ else:
184
+ inf_models.append((id_to_name.get(m, m), m))
185
+
186
+ llm_models = get_llm_models()
187
+ response_styles = get_response_styles()
188
+
189
+ with gr.Blocks(title="Personality Interpretation") as demo:
190
+ gr.Markdown("# Personality Analysis")
191
+
192
+ with gr.Tabs():
193
+ # ===== Tab 1: Raw Inference (existing) =====
194
+ with gr.TabItem("🔬 Inference"):
195
+ gr.Markdown("Test the raw inference API. Upload an image, choose a vision model, and get OCEAN trait scores.")
196
+ with gr.Row():
197
+ with gr.Column():
198
+ image_input = gr.Image(type="pil", label="Face Image")
199
+ with gr.Row():
200
+ model_dropdown = gr.Dropdown(
201
+ choices=models,
202
+ value=models[0][1] if models else None,
203
+ label="Inference Model"
204
+ )
205
+ refresh_btn = gr.Button("🔄 Refresh Models", size="sm")
206
+
207
+ submit_btn = gr.Button("Predict Personality", variant="primary")
208
+
209
+ with gr.Column():
210
+ output_json = gr.JSON(label="Personality Traits (OCEAN)")
211
+ cropped_output = gr.Image(type="pil", label="Extracted Face (Model Input)")
212
+
213
+ # Action mappings
214
+ submit_btn.click(
215
+ fn=predict,
216
+ inputs=[image_input, model_dropdown],
217
+ outputs=[output_json, cropped_output]
218
+ )
219
+
220
+ def refresh_models_list():
221
+ new_models = get_available_models()
222
+ return gr.update(choices=new_models, value=new_models[0][1] if new_models else None)
223
+
224
+ refresh_btn.click(
225
+ fn=refresh_models_list,
226
+ inputs=[],
227
+ outputs=[model_dropdown]
228
+ )
229
+
230
+ # ===== Tab 2: Full Interpretation =====
231
+ with gr.TabItem("✨ Interpretation"):
232
+ gr.Markdown("Upload an image and get a full personality analysis powered by vision models + LLM interpretation.")
233
+ with gr.Row():
234
+ with gr.Column():
235
+ interp_image = gr.Image(type="pil", label="Face Image")
236
+ with gr.Row():
237
+ interp_inf_dropdown = gr.Dropdown(
238
+ choices=inf_models,
239
+ value=inf_models[0][1] if inf_models else None,
240
+ label="Inference Model",
241
+ )
242
+ interp_llm_dropdown = gr.Dropdown(
243
+ choices=llm_models,
244
+ value=llm_models[0][1] if llm_models else None,
245
+ label="LLM Model",
246
+ )
247
+ style_dropdown = gr.Dropdown(
248
+ choices=response_styles,
249
+ value=response_styles[0][1] if response_styles else None,
250
+ label="Response Style"
251
+ )
252
+ interp_btn = gr.Button("Interpret Personality", variant="primary")
253
+ with gr.Column():
254
+ interp_traits = gr.JSON(label="Predicted Traits (OCEAN)")
255
+ interp_text = gr.Markdown(label="LLM Interpretation", value="*Interpretation will appear here...*")
256
+
257
+ export_btn = gr.DownloadButton("Export Result as JSON", variant="secondary")
258
+
259
+ def on_interpret(image, inf_model, llm_id, style_id):
260
+ return interpret(image, inf_model, llm_id, style_id)
261
+
262
+ interp_btn.click(
263
+ fn=on_interpret,
264
+ inputs=[interp_image, interp_inf_dropdown, interp_llm_dropdown, style_dropdown],
265
+ outputs=[interp_traits, interp_text],
266
+ )
267
+
268
+ export_btn.click(
269
+ fn=export_result,
270
+ inputs=[interp_image, interp_inf_dropdown, interp_llm_dropdown, style_dropdown, interp_traits, interp_text],
271
+ outputs=[export_btn]
272
+ )
273
+
274
+ return demo
275
+
276
+
277
+ if __name__ == "__main__":
278
+ app = build_app()
279
+ server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
280
+ server_port = int(os.getenv("GRADIO_SERVER_PORT", 7860))
281
+ app.launch(server_name=server_name, server_port=server_port, share=False, theme=gr.themes.Soft())
282
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio>=6.0.0
2
+ requests==2.32.3
3
+ Pillow==10.3.0