File size: 5,491 Bytes
f30516c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135

#Imports
import os
import tempfile
import streamlit as st
from PIL import Image
# from main import analyze_media, process_with_blip
from main import analyze_media
#import openai
import io
import zipfile

# Initialize OpenAI API
# from dotenv import load_dotenv
# load_dotenv()
# api_key = os.getenv("OPENAI_API_KEY")
# openai.api_key = api_key

# --- Page setup ---
st.set_page_config(page_title="Vision Sort", layout="wide")

# --- Sidebar: score thresholds (0-100) that are handed to analyze_media ---
with st.sidebar:
    st.header("Configuration")
    min_confidence = st.number_input(
        "Confidence Threshold",
        min_value=0,
        max_value=100,
        value=25,
        step=1,
    )
    borderline_min = st.number_input(
        "Borderline Minimum",
        min_value=0,
        max_value=100,
        value=15,
        step=1,
    )


# --- Main interface: title, media upload and the free-text search prompt ---
st.title("🔍 VisionSort Pro")
uploaded_files = st.file_uploader(
    "Upload images/videos",
    type=["jpg", "jpeg", "png", "mp4", "mov"],
    accept_multiple_files=True,
)
user_prompt = st.text_input(
    "Search prompt",
    placeholder="e.g. 'find the cat'",
)

def _render_result(res):
    """Show one match: the image at ``res["path"]`` and a score/timestamp caption."""
    # Image.open keeps the underlying file open until the image is closed.
    # Closing it here avoids leaking one handle per result and lets the
    # temporary files be deleted afterwards.
    with Image.open(res["path"]) as img:
        st.image(img, use_container_width=True)
    st.caption(f"{res['confidence']:.1f}% | {res['timestamp']:.2f}s")


def _render_grid(items, n_cols=4):
    """Lay out ``items`` left-to-right, top-to-bottom in ``n_cols`` columns."""
    cols = st.columns(n_cols)
    for idx, res in enumerate(items):
        with cols[idx % n_cols]:
            _render_result(res)


if uploaded_files and user_prompt:
    # Results are bucketed by the "status" key that analyze_media puts on each
    # result; the code below reads "status", "confidence", "timestamp" and
    # "path" from every result dict.
    results = {"high": [], "borderline": [], "low": []}
    temp_paths = []
    show_borderline = False
    show_low = False

    try:
        with st.spinner(f"Processing {len(uploaded_files)} files..."):
            for file in uploaded_files:
                suffix = os.path.splitext(file.name)[1]
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
                    # Register the path first so it is cleaned up even if the
                    # write fails part-way.
                    temp_paths.append(f.name)
                    # getvalue() returns the full upload regardless of the
                    # current stream position.
                    f.write(file.getvalue())

                # Analyse only after the temp file is closed: the data is then
                # guaranteed to be flushed to disk, and the file can be
                # reopened by name on platforms that forbid a second open.
                media_results = analyze_media(
                    f.name,
                    user_prompt,
                    min_confidence,
                    (borderline_min, min_confidence),
                )

                for res in media_results:
                    results[res["status"]].append(res)

        # Sort all groups by confidence descending
        for group in results.values():
            group.sort(key=lambda r: r["confidence"], reverse=True)

        # --- Display Confident Matches ---
        if results["high"]:
            st.subheader(f"🎯 Confident Matches ({len(results['high'])})")
            _render_grid(results["high"])

        # --- Display Borderline Matches ---
        if results["borderline"]:
            st.subheader(f"⚠️ Potential Matches ({len(results['borderline'])})")
            show_borderline = st.checkbox(
                "Show borderline results", True, key="show_borderline"
            )
            if show_borderline:
                _render_grid(results["borderline"])

        # --- Display Low Confidence Matches (single column, hidden by default) ---
        if results["low"]:
            st.subheader(f"❓ Low Confidence Matches ({len(results['low'])})")
            show_low = st.checkbox("Show low confidence results", key="show_low")
            if show_low:
                for res in results["low"]:
                    _render_result(res)

        # --- Prepare Downloadable Results ---
        # The archive mirrors what is currently displayed: confident matches
        # always, the other groups only while their checkbox is ticked.
        download_ready = list(results["high"])
        if show_borderline:
            download_ready += results["borderline"]
        if show_low:
            download_ready += results["low"]

        if download_ready:
            skipped = []
            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, "w") as zipf:
                for res in download_ready:
                    filename = os.path.basename(res["path"])
                    try:
                        zipf.write(res["path"], arcname=filename)
                    except (OSError, ValueError):
                        # Best effort: one unreadable file must not block the
                        # download, but the user is told about it below.
                        skipped.append(filename)
            zip_buffer.seek(0)

            if skipped:
                st.warning(
                    f"Skipped {len(skipped)} file(s) that could not be added "
                    "to the archive."
                )

            st.download_button(
                label="⬇️ Download Displayed Images",
                data=zip_buffer,
                file_name="visionSort_results.zip",
                mime="application/zip",
            )
    finally:
        # --- Cleanup Temporary Files ---
        # Runs on every path (no matches, or an error above) so uploads never
        # accumulate in the temp directory.
        for path in temp_paths:
            try:
                os.unlink(path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass