# Scraped page header (not code):
# lyfesan's picture
# Update app.py
# 50dd4d6 verified
import gradio as gr
import requests
import base64
import json
import tempfile
import os
from io import BytesIO
from PIL import Image
# Base URLs of the two backend services; each can be overridden through the
# environment (defaults point at local development servers).
INFERENCE_API_URL = os.environ.get("INFERENCE_API_URL", "http://127.0.0.1:8000")
INTERPRETATION_API_URL = os.environ.get("INTERPRETATION_API_URL", "http://127.0.0.1:8080")
def get_available_models():
    """Return the inference API's models as ``(label, model_id)`` dropdown choices.

    Queries ``GET {INFERENCE_API_URL}/models`` and reads its
    ``available_models`` list; each entry's ``name`` (or its ``id`` when no
    name is given) becomes the label. A hard-coded default list is returned
    when the status is not 200 or anything in the request/parsing fails.
    """
    fallback = [("SwinV2 (swinv2)", "swinv2"), ("ViT (vit)", "vit"), ("PVTv2 (pvtv2)", "pvtv2")]
    try:
        resp = requests.get(f"{INFERENCE_API_URL}/models", timeout=2)
        if resp.status_code != 200:
            return fallback
        entries = resp.json().get("available_models", [])
        return [(str(entry.get("name", entry.get("id"))), entry.get("id")) for entry in entries]
    except Exception as exc:
        print(f"Warning: Could not fetch models from API ({exc}). Using defaults.")
        return fallback
def predict(image, model_type):
    """Run raw OCEAN trait inference for one face image via the inference API.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None.
        model_type: Model id chosen in the dropdown (e.g. ``"swinv2"``).

    Returns:
        A ``(result, cropped_face)`` tuple. On success ``result`` is the
        ``predictions`` value from the API response and ``cropped_face`` is the
        RGB PIL image decoded from ``cropped_face_base64`` (None if absent or
        undecodable). On failure ``result`` is a dict with an ``error`` key
        (plus ``details`` where available) and ``cropped_face`` is None.
    """
    if image is None:
        return {"error": "Please upload an image."}, None
    if not model_type:
        return {"error": "Please select a model."}, None

    # JPEG cannot store alpha or palette modes (e.g. RGBA / P uploads), so
    # normalize to RGB before encoding instead of crashing in image.save().
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    payload = {
        "model_type": model_type,
        "image_base64": base64.b64encode(buffered.getvalue()).decode("utf-8"),
    }

    try:
        response = requests.post(f"{INFERENCE_API_URL}/predict", json=payload, timeout=30)
    except requests.RequestException as e:
        return {"error": "Connection failed. Is the API running?", "details": str(e)}, None

    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}", "details": response.text}, None

    # A 200 with an unparsable body is a server-side problem, not a connection
    # failure, so report it separately.
    try:
        data = response.json()
    except ValueError as e:
        return {"error": "Invalid JSON response from API.", "details": str(e)}, None
    if not isinstance(data, dict):
        return {"error": "Unexpected response format from API.", "details": response.text}, None

    predictions = data.get("predictions", {})
    cropped_img = None
    cropped_b64 = data.get("cropped_face_base64")
    if cropped_b64:
        try:
            img_data = base64.b64decode(cropped_b64)
            cropped_img = Image.open(BytesIO(img_data)).convert("RGB")
        except Exception as e:
            # The crop preview is best-effort; still return the predictions.
            print(f"Warning: Could not decode cropped face ({e}).")
    return predictions, cropped_img
# --- Interpretation API helpers ---
def get_inference_models():
    """Return the inference model entries listed by the interpretation API.

    ``GET {INTERPRETATION_API_URL}/inference-models`` may answer with a bare
    list or with an object holding an ``available_models`` list; both shapes
    are accepted. Entries are returned as-is (the caller handles both plain
    ids and dicts). Falls back to ``["swinv2", "vit", "pvtv2"]`` on a non-200
    status or any request/parsing error.
    """
    try:
        resp = requests.get(f"{INTERPRETATION_API_URL}/inference-models", timeout=2)
        if resp.status_code == 200:
            payload = resp.json()
            return payload.get("available_models", []) if isinstance(payload, dict) else payload
    except Exception as exc:
        print(f"Warning: Could not fetch inference models ({exc}).")
    return ["swinv2", "vit", "pvtv2"]
def get_llm_models():
    """Return ``(name, id)`` dropdown choices for the LLMs the interpretation API allows.

    Expects ``GET {INTERPRETATION_API_URL}/llm-models`` to return a list of
    objects with ``name`` and ``id`` keys. On a non-200 status or any
    request/parsing error, a single default model is offered instead.
    """
    default_choices = [("Gemma 4 31B (free)", "google/gemma-4-31b-it:free")]
    try:
        resp = requests.get(f"{INTERPRETATION_API_URL}/llm-models", timeout=2)
        if resp.status_code == 200:
            return [(entry["name"], entry["id"]) for entry in resp.json()]
    except Exception as exc:
        print(f"Warning: Could not fetch LLM models ({exc}).")
    return default_choices
def get_response_styles():
    """Return ``(name, id)`` dropdown choices for the interpretation response styles.

    Expects ``GET {INTERPRETATION_API_URL}/response-styles`` to return a list
    of objects with ``name`` and ``id`` keys. On a non-200 status or any
    request/parsing error, a single default style is offered instead.
    """
    default_choices = [("Comprehensive (ID)", "comprehensive_id")]
    try:
        resp = requests.get(f"{INTERPRETATION_API_URL}/response-styles", timeout=2)
        if resp.status_code == 200:
            return [(entry["name"], entry["id"]) for entry in resp.json()]
    except Exception as exc:
        print(f"Warning: Could not fetch response styles ({exc}).")
    return default_choices
def _interpretation_error_message(response):
    """Return the most useful error text from a non-200 interpretation response.

    Prefers a non-empty ``"error"`` field, then ``"detail"`` (the key FastAPI
    uses for ``HTTPException`` bodies), and falls back to the raw body text
    when the body is not a JSON object (e.g. an HTML error page from a proxy).
    """
    try:
        body = response.json()
    except ValueError:
        return response.text
    if isinstance(body, dict):
        for key in ("error", "detail"):
            if body.get(key):
                return body[key]
    return response.text


def interpret(image, inference_model, llm_model, style_id):
    """Send an image to the interpretation API via multipart/form-data.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None.
        inference_model: Vision model id used to predict the traits.
        llm_model: LLM id used to write the interpretation.
        style_id: Response style id (sent as-is; may be None).

    Returns:
        A ``(traits, message)`` tuple. On success ``traits`` is the response's
        ``predictions`` value and ``message`` its ``interpretation`` text. On
        any failure ``traits`` is ``{}`` and ``message`` describes the problem.
    """
    if image is None:
        return {}, "Please upload an image."
    if not inference_model:
        return {}, "Please select an inference model."
    if not llm_model:
        return {}, "Please select an LLM model."

    # JPEG cannot store alpha or palette modes (e.g. RGBA / P uploads), so
    # normalize to RGB before encoding instead of crashing in image.save().
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    buffered.seek(0)  # rewind so requests uploads the whole buffer

    files = {"image": ("image.jpg", buffered, "image/jpeg")}
    data = {
        "inference_model": inference_model,
        "llm_model": llm_model,
        "style_id": style_id,
    }
    try:
        response = requests.post(
            f"{INTERPRETATION_API_URL}/interpret",
            files=files,
            data=data,
            timeout=120,
        )
    except requests.RequestException as e:
        return {}, f"Connection failed. Is the interpretation API running?\n{e}"

    if response.status_code != 200:
        # Error bodies are not guaranteed to be JSON objects; extract safely so
        # an HTML/plain-text error is not misreported as a connection failure.
        return {}, f"Error {response.status_code}: {_interpretation_error_message(response)}"

    try:
        result = response.json()
    except ValueError as e:
        return {}, f"Invalid JSON response from the interpretation API.\n{e}"
    if not isinstance(result, dict):
        return {}, "Unexpected response format from the interpretation API."

    traits = result.get("predictions", {})
    interpretation = result.get("interpretation", "No interpretation returned.")
    return traits, interpretation
def export_result(image, inf_model, llm_id, style_id, traits, interpretation):
    """Write an interpretation result to a temporary JSON file and return its path.

    Args:
        image: PIL image that was analysed, or None; embedded as base64 JPEG.
        inf_model: Inference model id used.
        llm_id: LLM model id used.
        style_id: Response style id used.
        traits: Predicted trait scores to store under ``results.predictions``.
        interpretation: Text to store under ``results.interpretation``.

    Returns:
        Path of a new ``personality_export_*.json`` file (not deleted here),
        or None when there are neither traits nor interpretation to export.
    """
    if not traits and not interpretation:
        return None  # Nothing to export

    img_b64 = None
    if image is not None:
        # JPEG cannot store alpha or palette modes (e.g. RGBA / P), so
        # normalize to RGB before encoding instead of crashing in image.save().
        if image.mode != "RGB":
            image = image.convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    data = {
        "parameters": {
            "inference_model": inf_model,
            "llm_model": llm_id,
            "response_style": style_id,
        },
        "results": {
            "predictions": traits,
            "interpretation": interpretation,
        },
        "image_base64": img_b64,
    }
    fd, path = tempfile.mkstemp(suffix=".json", prefix="personality_export_")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        # The file is UTF-8, so keep non-ASCII text readable instead of
        # escaping it as \uXXXX sequences.
        json.dump(data, f, indent=4, ensure_ascii=False)
    return path
# --- Build combined app ---
def build_app():
    """Build the two-tab Gradio UI: raw inference and full LLM interpretation.

    Model, LLM, and response-style choices are fetched once from the backend
    APIs when the app is built; each fetcher falls back to defaults if its API
    is unreachable. Only the inference tab offers a model-refresh button.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    models = get_available_models()
    inf_models_raw = get_inference_models()
    # Map inference model IDs to display names (Name, ID). Bare ids from the
    # interpretation API reuse the inference API's names when known.
    id_to_name = {m_id: m_name for m_name, m_id in models}
    inf_models = []
    for m in inf_models_raw:
        if isinstance(m, dict):
            inf_models.append((m.get("name", m.get("id")), m.get("id")))
        else:
            inf_models.append((id_to_name.get(m, m), m))
    llm_models = get_llm_models()
    response_styles = get_response_styles()

    with gr.Blocks(title="Personality Interpretation") as demo:
        gr.Markdown("# Personality Analysis")
        with gr.Tabs():
            # ===== Tab 1: Raw Inference (existing) =====
            with gr.TabItem("🔬 Inference"):
                gr.Markdown("Test the raw inference API. Upload an image, choose a vision model, and get OCEAN trait scores.")
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(type="pil", label="Face Image")
                        with gr.Row():
                            model_dropdown = gr.Dropdown(
                                choices=models,
                                value=models[0][1] if models else None,
                                label="Inference Model"
                            )
                            refresh_btn = gr.Button("🔄 Refresh Models", size="sm")
                        submit_btn = gr.Button("Predict Personality", variant="primary")
                    with gr.Column():
                        output_json = gr.JSON(label="Personality Traits (OCEAN)")
                        cropped_output = gr.Image(type="pil", label="Extracted Face (Model Input)")

                # Action mappings
                submit_btn.click(
                    fn=predict,
                    inputs=[image_input, model_dropdown],
                    outputs=[output_json, cropped_output]
                )

                def refresh_models_list():
                    new_models = get_available_models()
                    return gr.update(choices=new_models, value=new_models[0][1] if new_models else None)

                refresh_btn.click(
                    fn=refresh_models_list,
                    inputs=[],
                    outputs=[model_dropdown]
                )

            # ===== Tab 2: Full Interpretation =====
            with gr.TabItem("✨ Interpretation"):
                gr.Markdown("Upload an image and get a full personality analysis powered by vision models + LLM interpretation.")
                with gr.Row():
                    with gr.Column():
                        interp_image = gr.Image(type="pil", label="Face Image")
                        with gr.Row():
                            interp_inf_dropdown = gr.Dropdown(
                                choices=inf_models,
                                value=inf_models[0][1] if inf_models else None,
                                label="Inference Model",
                            )
                            interp_llm_dropdown = gr.Dropdown(
                                choices=llm_models,
                                value=llm_models[0][1] if llm_models else None,
                                label="LLM Model",
                            )
                        style_dropdown = gr.Dropdown(
                            choices=response_styles,
                            value=response_styles[0][1] if response_styles else None,
                            label="Response Style"
                        )
                        interp_btn = gr.Button("Interpret Personality", variant="primary")
                    with gr.Column():
                        interp_traits = gr.JSON(label="Predicted Traits (OCEAN)")
                        interp_text = gr.Markdown(label="LLM Interpretation", value="*Interpretation will appear here...*")
                        export_btn = gr.DownloadButton("Export Result as JSON", variant="secondary")

                interp_inputs = [interp_image, interp_inf_dropdown, interp_llm_dropdown, style_dropdown]
                # gr.DownloadButton downloads the file it already holds when
                # clicked, so generating the export in its own click handler
                # only takes effect on the *next* click. Instead, prepare the
                # export as soon as an interpretation finishes; the inputs are
                # read right after the result is produced, so the exported
                # parameters match what is on screen.
                interp_btn.click(
                    fn=interpret,
                    inputs=interp_inputs,
                    outputs=[interp_traits, interp_text],
                ).success(
                    fn=export_result,
                    inputs=interp_inputs + [interp_traits, interp_text],
                    outputs=[export_btn],
                )
    return demo
if __name__ == "__main__":
    demo_app = build_app()
    # Bind address and port are configurable through the environment.
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo_app.launch(
        server_name=host,
        server_port=port,
        share=False,
        theme=gr.themes.Soft(),
    )