File size: 5,567 Bytes
b0c3a57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b7432c
 
b0c3a57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
OphthalmoCapture — Ephemeral Session Manager

All image data lives exclusively in st.session_state (RAM).
Nothing is written to disk. Data is only persisted when the user
explicitly downloads their labeling package.
"""

import streamlit as st
import uuid
import datetime
import gc


def init_session():
    """Initialize the ephemeral session data model."""
    if "session_initialized" not in st.session_state:
        st.session_state.session_initialized = True
        st.session_state.session_id = str(uuid.uuid4())  # unique per session
        st.session_state.images = {}            # {uuid_str: image_data_dict}
        st.session_state.image_order = []       # [uuid_str, ...] upload order
        st.session_state.current_image_id = None
        st.session_state.last_activity = datetime.datetime.now()
        st.session_state.doctor_name = ""
        st.session_state.confirm_end_session = False


def add_image(filename: str, image_bytes: bytes) -> str:
    """Add an uploaded image to the in-memory session store.

    Returns the generated UUID for the image.
    """
    img_id = str(uuid.uuid4())
    st.session_state.images[img_id] = {
        "filename": filename,
        "bytes": image_bytes,
        "label": None,                 # Categorical: Normal/Cataract/Bad quality/Needs dilation
        "locs_data": {},               # LOCS III: {"nuclear_opalescence": int, "nuclear_color": int, "cortical_opacity": int}
        "audio_bytes": None,           # WAV from recording (Phase 4)
        "transcription": "",           # Editable transcription text
        "transcription_original": "",  # Original Whisper output (read-only)
        "timestamp": datetime.datetime.now(),
        "labeled_by": st.session_state.get("doctor_name", ""),
    }
    st.session_state.image_order.append(img_id)
    update_activity()
    return img_id


def remove_image(img_id: str):
    """Remove a single image from the session, freeing memory."""
    if img_id in st.session_state.images:
        # Explicitly clear heavy byte fields before deletion
        st.session_state.images[img_id]["bytes"] = None
        st.session_state.images[img_id]["audio_bytes"] = None
        del st.session_state.images[img_id]

    if img_id in st.session_state.image_order:
        st.session_state.image_order.remove(img_id)

    # Update current selection if the deleted image was active
    if st.session_state.current_image_id == img_id:
        if st.session_state.image_order:
            st.session_state.current_image_id = st.session_state.image_order[0]
        else:
            st.session_state.current_image_id = None


def get_current_image():
    """Get the data dict for the currently selected image, or None."""
    img_id = st.session_state.get("current_image_id")
    if img_id and img_id in st.session_state.images:
        return st.session_state.images[img_id]
    return None


def get_current_image_id():
    """Get the UUID of the currently selected image."""
    return st.session_state.get("current_image_id")


def set_current_image(img_id: str):
    """Set the currently active image by UUID."""
    if img_id in st.session_state.images:
        st.session_state.current_image_id = img_id
        update_activity()


def get_image_count() -> int:
    """Total number of images in session."""
    return len(st.session_state.images)


def get_labeling_progress():
    """Return (labeled_count, total_count)."""
    total = len(st.session_state.images)
    labeled = sum(
        1 for img in st.session_state.images.values()
        if img["label"] is not None
    )
    return labeled, total


def has_undownloaded_data() -> bool:
    """Check if there is any data in the session."""
    return len(st.session_state.images) > 0


def update_activity():
    """Update the last activity timestamp."""
    st.session_state.last_activity = datetime.datetime.now()


def check_session_timeout(timeout_minutes: int = 30) -> bool:
    """Return True if the session has exceeded the inactivity timeout."""
    last = st.session_state.get("last_activity")
    if last:
        elapsed = (datetime.datetime.now() - last).total_seconds() / 60
        return elapsed > timeout_minutes
    return False


def clear_session():
    """Completely wipe all session data — images, audio, everything.

    Called on explicit cleanup or session timeout.
    """
    # Explicitly null out heavy byte fields to help garbage collection
    for img in st.session_state.get("images", {}).values():
        img["bytes"] = None
        img["audio_bytes"] = None
    st.session_state.clear()
    gc.collect()


def get_remaining_timeout_minutes(timeout_minutes: int = 30) -> float:
    """Return how many minutes remain before timeout, or 0 if already expired."""
    last = st.session_state.get("last_activity")
    if not last:
        return 0.0
    elapsed = (datetime.datetime.now() - last).total_seconds() / 60
    remaining = timeout_minutes - elapsed
    return max(0.0, remaining)


def get_session_data_summary() -> dict:
    """Return a summary of what data exists in the session (for warnings)."""
    images = st.session_state.get("images", {})
    total = len(images)
    labeled = sum(1 for img in images.values() if img["label"] is not None)
    with_audio = sum(1 for img in images.values() if img["audio_bytes"] is not None)
    with_text = sum(1 for img in images.values() if img["transcription"])
    return {
        "total": total,
        "labeled": labeled,
        "with_audio": with_audio,
        "with_transcription": with_text,
    }