| import gradio as gr |
| import requests |
| import base64 |
| import json |
| import tempfile |
| import os |
| from io import BytesIO |
| from PIL import Image |
|
|
# Base URLs of the two backend services, overridable via environment variables.
# Trailing slashes are stripped so that f"{URL}/path" never yields "//path".
INFERENCE_API_URL = os.getenv("INFERENCE_API_URL", "http://127.0.0.1:8000").rstrip("/")
INTERPRETATION_API_URL = os.getenv("INTERPRETATION_API_URL", "http://127.0.0.1:8080").rstrip("/")
|
|
|
|
def get_available_models():
    """Return the inference server's models as ``(label, id)`` dropdown choices.

    Queries ``GET {INFERENCE_API_URL}/models`` and reads its
    ``"available_models"`` list. Each entry's label is its ``"name"`` (or its
    ``"id"`` when no name is given), converted to a string.

    A hard-coded list of three defaults is returned when:
    - the status code is not 200, or
    - any exception is raised while fetching or reading the payload.

    A warning is printed only in the exception case.
    """
    fallback = [("SwinV2 (swinv2)", "swinv2"), ("ViT (vit)", "vit"), ("PVTv2 (pvtv2)", "pvtv2")]
    try:
        resp = requests.get(f"{INFERENCE_API_URL}/models", timeout=2)
        if resp.status_code != 200:
            return fallback
        entries = resp.json().get("available_models", [])
        choices = []
        for entry in entries:
            label = str(entry.get("name", entry.get("id")))
            choices.append((label, entry.get("id")))
        return choices
    except Exception as e:
        print(f"Warning: Could not fetch models from API ({e}). Using defaults.")
        return fallback
|
|
def predict(image, model_type):
    """Send an image to the inference API and return its trait predictions.

    The image is JPEG-encoded, base64-encoded and POSTed as JSON
    (``model_type`` + ``image_base64``) to ``{INFERENCE_API_URL}/predict``.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was uploaded.
        model_type: Inference model id selected in the UI.

    Returns:
        A ``(predictions, cropped_face)`` tuple.

        On success:
        - ``predictions`` is the response's ``"predictions"`` value (``{}`` if
          absent).
        - ``cropped_face`` is the decoded ``"cropped_face_base64"`` image
          converted to RGB, or ``None`` if that field is missing or cannot be
          decoded.

        On any failure, ``predictions`` is a dict with an ``"error"`` key
        (plus ``"details"`` when available) and ``cropped_face`` is ``None``.
    """
    if image is None:
        return {"error": "Please upload an image."}, None
    if not model_type:
        return {"error": "Please select a model."}, None

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    payload = {
        "model_type": model_type,
        "image_base64": base64.b64encode(buffered.getvalue()).decode("utf-8"),
    }

    # Only transport-level problems are reported as a connection failure;
    # bad statuses and malformed bodies get their own messages below.
    try:
        response = requests.post(f"{INFERENCE_API_URL}/predict", json=payload, timeout=30)
    except requests.RequestException as e:
        return {"error": "Connection failed. Is the API running?", "details": str(e)}, None

    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}", "details": response.text}, None

    try:
        data = response.json()
    except ValueError as e:
        return {"error": "Invalid JSON response from API.", "details": str(e)}, None
    if not isinstance(data, dict):
        return {"error": "Unexpected response format from API.", "details": response.text}, None

    cropped_img = None
    cropped_b64 = data.get("cropped_face_base64")
    if cropped_b64:
        try:
            img_data = base64.b64decode(cropped_b64)
            cropped_img = Image.open(BytesIO(img_data)).convert("RGB")
        except Exception as e:
            # Best effort: the predictions are still useful without the crop,
            # but leave a trace instead of failing silently.
            print(f"Warning: Could not decode cropped face image ({e}).")

    return data.get("predictions", {}), cropped_img
|
|
| |
|
|
def get_inference_models():
    """Fetch the inference-model list from the interpretation API.

    ``GET {INTERPRETATION_API_URL}/inference-models`` may return either a bare
    JSON list or an object whose ``"available_models"`` key holds the list.
    Items are returned unchanged; the caller accepts both plain id strings and
    ``{"id", "name"}`` dicts.

    Returns:
        The list of model entries. Falls back to ``["swinv2", "vit", "pvtv2"]``
        when any of the following happens (a warning is printed in each case):
        - the request raises,
        - the status is not 200,
        - the body is not valid JSON,
        - the payload does not contain a list.
    """
    fallback = ["swinv2", "vit", "pvtv2"]
    try:
        response = requests.get(f"{INTERPRETATION_API_URL}/inference-models", timeout=2)
        if response.status_code != 200:
            print(f"Warning: Could not fetch inference models (HTTP {response.status_code}).")
            return fallback
        data = response.json()
    except Exception as e:
        print(f"Warning: Could not fetch inference models ({e}).")
        return fallback

    models = data.get("available_models", []) if isinstance(data, dict) else data
    # Anything other than a list (e.g. JSON null, a number or a string) would
    # break or be mis-iterated by the dropdown-building code, so reject it here.
    if not isinstance(models, list):
        print(f"Warning: Unexpected inference models payload ({type(models).__name__}).")
        return fallback
    return models
|
|
def get_llm_models():
    """Return the interpretation API's allowed LLMs as ``(name, id)`` choices.

    Reads ``GET {INTERPRETATION_API_URL}/llm-models``, whose entries are
    indexed by ``"name"`` and ``"id"``. A single default model is returned
    when the status is not 200. The same default is returned, after printing
    a warning, when any exception occurs, such as a network error, invalid
    JSON, or an entry missing one of those keys.
    """
    default_choices = [("Gemma 4 31B (free)", "google/gemma-4-31b-it:free")]
    try:
        resp = requests.get(f"{INTERPRETATION_API_URL}/llm-models", timeout=2)
        if resp.status_code != 200:
            return default_choices
        return [(entry["name"], entry["id"]) for entry in resp.json()]
    except Exception as exc:
        print(f"Warning: Could not fetch LLM models ({exc}).")
    return default_choices
|
|
def get_response_styles():
    """Return the interpretation API's response styles as ``(name, id)`` choices.

    Reads ``GET {INTERPRETATION_API_URL}/response-styles``, whose entries are
    indexed by ``"name"`` and ``"id"``. A single default style is returned
    when the status is not 200. The same default is returned, after printing
    a warning, when any exception occurs while fetching or reading the
    payload.
    """
    default_choices = [("Comprehensive (ID)", "comprehensive_id")]
    try:
        resp = requests.get(f"{INTERPRETATION_API_URL}/response-styles", timeout=2)
        if resp.status_code != 200:
            return default_choices
        return [(entry["name"], entry["id"]) for entry in resp.json()]
    except Exception as exc:
        print(f"Warning: Could not fetch response styles ({exc}).")
    return default_choices
|
|
def interpret(image, inference_model, llm_model, style_id):
    """Send an image to the interpretation API via multipart/form-data.

    The image is JPEG-encoded and POSTed as the ``image`` file part of
    ``{INTERPRETATION_API_URL}/interpret``, together with form fields
    ``inference_model``, ``llm_model`` and ``style_id``.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was uploaded.
        inference_model: Vision model id.
        llm_model: LLM id.
        style_id: Response style id.

    Returns:
        A ``(traits, text)`` tuple.

        On success:
        - ``traits`` is the response's ``"predictions"`` (``{}`` if absent).
        - ``text`` is its ``"interpretation"``.

        On failure, ``traits`` is ``{}`` and ``text`` is a human-readable
        error message.
    """
    if image is None:
        return {}, "Please upload an image."
    if not inference_model:
        return {}, "Please select an inference model."
    if not llm_model:
        return {}, "Please select an LLM model."

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    buffered.seek(0)  # requests streams the file part from the current position

    files = {"image": ("image.jpg", buffered, "image/jpeg")}
    data = {
        "inference_model": inference_model,
        "llm_model": llm_model,
        "style_id": style_id,
    }
    # Only transport-level problems are reported as a connection failure.
    try:
        response = requests.post(
            f"{INTERPRETATION_API_URL}/interpret",
            files=files,
            data=data,
            timeout=120,
        )
    except requests.RequestException as e:
        return {}, f"Connection failed. Is the interpretation API running?\n{e}"

    # Error bodies are not guaranteed to be JSON objects (e.g. a proxy's HTML
    # error page), so parse defensively instead of letting it raise.
    try:
        result = response.json()
    except ValueError:
        result = None

    if response.status_code != 200:
        err = result.get("error", response.text) if isinstance(result, dict) else response.text
        return {}, f"Error {response.status_code}: {err}"

    if not isinstance(result, dict):
        return {}, f"Invalid response from the interpretation API:\n{response.text}"

    traits = result.get("predictions", {})
    interpretation = result.get("interpretation", "No interpretation returned.")
    return traits, interpretation
|
|
def export_result(image, inf_model, llm_id, style_id, traits, interpretation):
    """Export an interpretation result to a temporary JSON file.

    Args:
        image: PIL image that was analysed, or ``None``. When given, it is
            embedded as a base64-encoded JPEG under ``"image_base64"``.
        inf_model: Inference model id, stored under ``parameters``.
        llm_id: LLM id, stored under ``parameters``.
        style_id: Response style id, stored under ``parameters``.
        traits: Predicted traits; must be JSON-serialisable.
        interpretation: Interpretation text.

    Returns:
        Path of a new ``personality_export_*.json`` file, or ``None`` when both
        ``traits`` and ``interpretation`` are empty. This function does not
        delete the file.

    Raises:
        Any serialisation error from ``json.dump`` (e.g. ``TypeError``). The
        partially written temp file is removed before re-raising.
    """
    if not traits and not interpretation:
        return None

    img_b64 = None
    if image is not None:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    data = {
        "parameters": {
            "inference_model": inf_model,
            "llm_model": llm_id,
            "response_style": style_id,
        },
        "results": {
            "predictions": traits,
            "interpretation": interpretation,
        },
        "image_base64": img_b64,
    }

    fd, path = tempfile.mkstemp(suffix=".json", prefix="personality_export_")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            # The file is UTF-8, so keep non-ASCII interpretation text readable
            # instead of \u-escaping it.
            json.dump(data, f, indent=4, ensure_ascii=False)
    except Exception:
        # Don't leak a half-written temp file when serialisation fails.
        os.remove(path)
        raise
    return path
|
|
|
|
| |
|
|
def build_app():
    """Build the Gradio UI with an "Inference" tab and an "Interpretation" tab.

    Dropdown choices are fetched from the backend APIs once, while this
    function runs. The Inference tab can re-fetch its model list through its
    refresh button.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    models = get_available_models()
    inf_models_raw = get_inference_models()

    # The interpretation API may list plain model ids; reuse the inference
    # server's display names for those ids where one is known.
    id_to_name = {m_id: m_name for m_name, m_id in models}

    inf_models = []
    for m in inf_models_raw:
        if isinstance(m, dict):
            inf_models.append((m.get("name", m.get("id")), m.get("id")))
        else:
            inf_models.append((id_to_name.get(m, m), m))

    llm_models = get_llm_models()
    response_styles = get_response_styles()

    with gr.Blocks(title="Personality Interpretation") as demo:
        gr.Markdown("# Personality Analysis")

        with gr.Tabs():
            # Tab 1: call the inference API directly.
            with gr.TabItem("🔬 Inference"):
                gr.Markdown("Test the raw inference API. Upload an image, choose a vision model, and get OCEAN trait scores.")
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(type="pil", label="Face Image")
                        with gr.Row():
                            model_dropdown = gr.Dropdown(
                                choices=models,
                                value=models[0][1] if models else None,
                                label="Inference Model",
                            )
                            refresh_btn = gr.Button("🔄 Refresh Models", size="sm")

                        submit_btn = gr.Button("Predict Personality", variant="primary")

                    with gr.Column():
                        output_json = gr.JSON(label="Personality Traits (OCEAN)")
                        cropped_output = gr.Image(type="pil", label="Extracted Face (Model Input)")

                submit_btn.click(
                    fn=predict,
                    inputs=[image_input, model_dropdown],
                    outputs=[output_json, cropped_output],
                )

                def refresh_models_list():
                    """Re-fetch the inference models and reset the dropdown to the first one."""
                    new_models = get_available_models()
                    return gr.update(choices=new_models, value=new_models[0][1] if new_models else None)

                refresh_btn.click(
                    fn=refresh_models_list,
                    inputs=[],
                    outputs=[model_dropdown],
                )

            # Tab 2: vision model + LLM interpretation via the interpretation API.
            with gr.TabItem("✨ Interpretation"):
                gr.Markdown("Upload an image and get a full personality analysis powered by vision models + LLM interpretation.")
                with gr.Row():
                    with gr.Column():
                        interp_image = gr.Image(type="pil", label="Face Image")
                        with gr.Row():
                            interp_inf_dropdown = gr.Dropdown(
                                choices=inf_models,
                                value=inf_models[0][1] if inf_models else None,
                                label="Inference Model",
                            )
                            interp_llm_dropdown = gr.Dropdown(
                                choices=llm_models,
                                value=llm_models[0][1] if llm_models else None,
                                label="LLM Model",
                            )
                            style_dropdown = gr.Dropdown(
                                choices=response_styles,
                                value=response_styles[0][1] if response_styles else None,
                                label="Response Style",
                            )
                        interp_btn = gr.Button("Interpret Personality", variant="primary")
                    with gr.Column():
                        interp_traits = gr.JSON(label="Predicted Traits (OCEAN)")
                        interp_text = gr.Markdown(label="LLM Interpretation", value="*Interpretation will appear here...*")

                        export_btn = gr.DownloadButton("Export Result as JSON", variant="secondary")

                # The export file is generated as a follow-up step of the
                # interpretation, not in the DownloadButton's own click handler.
                # This way the exported image/model/style are exactly the ones
                # that produced the displayed result. It also avoids exporting
                # the Markdown placeholder before any run.
                # NOTE(review): gr.DownloadButton appears to download the value
                # it holds when clicked, so setting that value from its own
                # click handler only took effect on the *next* click. Confirm
                # against the installed Gradio version.
                interp_btn.click(
                    fn=interpret,
                    inputs=[interp_image, interp_inf_dropdown, interp_llm_dropdown, style_dropdown],
                    outputs=[interp_traits, interp_text],
                ).then(
                    fn=export_result,
                    inputs=[interp_image, interp_inf_dropdown, interp_llm_dropdown, style_dropdown, interp_traits, interp_text],
                    outputs=[export_btn],
                )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the UI first (this queries the backend APIs), then serve it on the
    # host/port from GRADIO_SERVER_NAME / GRADIO_SERVER_PORT, defaulting to
    # 0.0.0.0:7860, without a public share link.
    app = build_app()
    app.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", 7860)),
        share=False,
        theme=gr.themes.Soft(),
    )
|
|
|
|