Spaces:
Sleeping
Sleeping
updating new version with supabase and vlm
Browse files- app.py +3 -1
- config/settings.py +5 -2
- logic/data_utils.py +36 -4
- logic/handlers.py +87 -18
- logic/supabase_client.py +193 -0
- logic/vlm.py +440 -0
- requirements.txt +67 -12
- ui/layout.py +308 -43
- ui/main_page.py +40 -0
- ui/selection_page.py +20 -2
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
# import spacy.cli
|
| 2 |
# spacy.cli.download("ja_core_news_sm")
|
| 3 |
# spacy.cli.download("zh_core_web_sm")
|
|
|
|
|
|
|
| 4 |
import spacy_udpipe
|
| 5 |
spacy_udpipe.download("ja")
|
| 6 |
spacy_udpipe.download("zh")
|
|
@@ -16,7 +18,7 @@ metadata = load_metadata()
|
|
| 16 |
|
| 17 |
demo = build_ui(concepts, metadata, HF_API_TOKEN, HF_DATASET_NAME)
|
| 18 |
# demo.launch()
|
| 19 |
-
demo.launch(debug=False)
|
| 20 |
|
| 21 |
demo.close()
|
| 22 |
# gr.close_all()
|
|
|
|
| 1 |
# import spacy.cli
|
| 2 |
# spacy.cli.download("ja_core_news_sm")
|
| 3 |
# spacy.cli.download("zh_core_web_sm")
|
| 4 |
+
import os
|
| 5 |
+
os.environ["TF_USE_LEGACY_KERAS"] = "1"
|
| 6 |
import spacy_udpipe
|
| 7 |
spacy_udpipe.download("ja")
|
| 8 |
spacy_udpipe.download("zh")
|
|
|
|
| 18 |
|
| 19 |
demo = build_ui(concepts, metadata, HF_API_TOKEN, HF_DATASET_NAME)
|
| 20 |
# demo.launch()
|
| 21 |
+
demo.launch(debug=False, server_port=7861)
|
| 22 |
|
| 23 |
demo.close()
|
| 24 |
# gr.close_all()
|
config/settings.py
CHANGED
|
@@ -3,6 +3,9 @@ import os
|
|
| 3 |
|
| 4 |
load_dotenv()
|
| 5 |
|
| 6 |
-
HF_API_TOKEN = os.getenv("
|
| 7 |
HF_DATASET_NAME = os.getenv("HF_DATASET_NAME")
|
| 8 |
-
LOCAL_DS_DIRECTORY_PATH = os.getenv("LOCAL_DS_DIRECTORY_PATH")
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
load_dotenv()
|
| 5 |
|
| 6 |
+
HF_API_TOKEN = os.getenv("HF_TOKEN")
|
| 7 |
HF_DATASET_NAME = os.getenv("HF_DATASET_NAME")
|
| 8 |
+
LOCAL_DS_DIRECTORY_PATH = os.getenv("LOCAL_DS_DIRECTORY_PATH")
|
| 9 |
+
SUPABASE_URL: str = os.getenv("SUPABASE_URL")
|
| 10 |
+
SUPABASE_KEY: str = os.getenv("SUPABASE_KEY")
|
| 11 |
+
REDIRECT_TO_URL: str = os.getenv("REDIRECT_TO_URL")
|
logic/data_utils.py
CHANGED
|
@@ -8,7 +8,7 @@ import uuid
|
|
| 8 |
import gradio as gr
|
| 9 |
from PIL import Image
|
| 10 |
import numpy as np
|
| 11 |
-
|
| 12 |
|
| 13 |
def load_concepts(path="data/concepts.json"):
|
| 14 |
with open(path, encoding='utf-8') as f:
|
|
@@ -53,6 +53,9 @@ class CustomHFDatasetSaver:
|
|
| 53 |
self.local_ds_folder = local_ds_folder
|
| 54 |
os.makedirs(self.local_ds_folder, exist_ok=True)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
self.data_outputs = data_outputs # list of components to read values from
|
| 57 |
|
| 58 |
# create scheduler to commit the data to the hub every x minutes
|
|
@@ -63,6 +66,27 @@ class CustomHFDatasetSaver:
|
|
| 63 |
every=1,
|
| 64 |
token=self.api_token,
|
| 65 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
def validate_data(self, values_dic):
|
| 68 |
"""
|
|
@@ -166,10 +190,16 @@ class CustomHFDatasetSaver:
|
|
| 166 |
values_dic["id"] = f'{country}_{language}_{category}_{concept}_{current_timestamp}'
|
| 167 |
|
| 168 |
#prepare the main directory of the sample
|
| 169 |
-
if
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
else:
|
| 172 |
sample_dir = os.path.join("anonymous_users", values_dic["country"], values_dic["language"], str(uuid.uuid4()), str(current_timestamp))
|
|
|
|
| 173 |
|
| 174 |
os.makedirs(os.path.join(self.local_ds_folder, sample_dir), exist_ok=True)
|
| 175 |
|
|
@@ -217,6 +247,8 @@ class CustomHFDatasetSaver:
|
|
| 217 |
# "image_file": image_file_path_on_hub,
|
| 218 |
"image_url": values_dic['image_url'] or "",
|
| 219 |
"caption": values_dic['caption'] or "",
|
|
|
|
|
|
|
| 220 |
"country": values_dic['country'] or "",
|
| 221 |
"language": values_dic['language'] or "",
|
| 222 |
"category": values_dic['category'] or "",
|
|
@@ -227,7 +259,7 @@ class CustomHFDatasetSaver:
|
|
| 227 |
"category_4_concepts": values_dic.get('category_4_concepts') or [""],
|
| 228 |
"category_5_concepts": values_dic.get('category_5_concepts') or [""],
|
| 229 |
"timestamp": current_timestamp,
|
| 230 |
-
"username":
|
| 231 |
"password": values_dic['password'] or "",
|
| 232 |
"id": values_dic['id'],
|
| 233 |
"excluded": False if values_dic.get('excluded') is None else bool(values_dic.get('excluded')),
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
from PIL import Image
|
| 10 |
import numpy as np
|
| 11 |
+
from logic.supabase_client import auth_handler
|
| 12 |
|
| 13 |
def load_concepts(path="data/concepts.json"):
|
| 14 |
with open(path, encoding='utf-8') as f:
|
|
|
|
| 53 |
self.local_ds_folder = local_ds_folder
|
| 54 |
os.makedirs(self.local_ds_folder, exist_ok=True)
|
| 55 |
|
| 56 |
+
# Migrate any existing JSON files to include new VLM fields
|
| 57 |
+
self._migrate_existing()
|
| 58 |
+
|
| 59 |
self.data_outputs = data_outputs # list of components to read values from
|
| 60 |
|
| 61 |
# create scheduler to commit the data to the hub every x minutes
|
|
|
|
| 66 |
every=1,
|
| 67 |
token=self.api_token,
|
| 68 |
)
|
| 69 |
+
|
| 70 |
+
def _migrate_existing(self):
|
| 71 |
+
"""
|
| 72 |
+
Ensure all existing JSON sample files have the same schema
|
| 73 |
+
by adding missing keys for 'vlm_caption' and 'vlm_feedback'.
|
| 74 |
+
"""
|
| 75 |
+
for root, _, files in os.walk(self.local_ds_folder):
|
| 76 |
+
for fname in files:
|
| 77 |
+
if fname.endswith('.json'):
|
| 78 |
+
fpath = os.path.join(root, fname)
|
| 79 |
+
with open(fpath, 'r+', encoding='utf-8') as f:
|
| 80 |
+
data = json.load(f)
|
| 81 |
+
updated = False
|
| 82 |
+
for key in ['vlm_caption', 'vlm_feedback']:
|
| 83 |
+
if key not in data:
|
| 84 |
+
data[key] = ""
|
| 85 |
+
updated = True
|
| 86 |
+
if updated:
|
| 87 |
+
f.seek(0)
|
| 88 |
+
json.dump(data, f, indent=2)
|
| 89 |
+
f.truncate()
|
| 90 |
|
| 91 |
def validate_data(self, values_dic):
|
| 92 |
"""
|
|
|
|
| 190 |
values_dic["id"] = f'{country}_{language}_{category}_{concept}_{current_timestamp}'
|
| 191 |
|
| 192 |
#prepare the main directory of the sample
|
| 193 |
+
# here we check if the user is logged in or not
|
| 194 |
+
user_info = auth_handler.is_logged_in(values_dic.get("client", None))
|
| 195 |
+
print(f"User info: {user_info}")
|
| 196 |
+
if user_info['success']:
|
| 197 |
+
# sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], values_dic["username"], str(current_timestamp))
|
| 198 |
+
sample_dir = os.path.join("logged_in_users", values_dic["country"], values_dic["language"], user_info['email'], str(current_timestamp))
|
| 199 |
+
print(f"Sample directory for logged in user: {sample_dir}")
|
| 200 |
else:
|
| 201 |
sample_dir = os.path.join("anonymous_users", values_dic["country"], values_dic["language"], str(uuid.uuid4()), str(current_timestamp))
|
| 202 |
+
print(f"Sample directory: {sample_dir}")
|
| 203 |
|
| 204 |
os.makedirs(os.path.join(self.local_ds_folder, sample_dir), exist_ok=True)
|
| 205 |
|
|
|
|
| 247 |
# "image_file": image_file_path_on_hub,
|
| 248 |
"image_url": values_dic['image_url'] or "",
|
| 249 |
"caption": values_dic['caption'] or "",
|
| 250 |
+
"vlm_caption": values_dic['vlm_caption'] or "",
|
| 251 |
+
"vlm_feedback": values_dic['vlm_feedback'] or "",
|
| 252 |
"country": values_dic['country'] or "",
|
| 253 |
"language": values_dic['language'] or "",
|
| 254 |
"category": values_dic['category'] or "",
|
|
|
|
| 259 |
"category_4_concepts": values_dic.get('category_4_concepts') or [""],
|
| 260 |
"category_5_concepts": values_dic.get('category_5_concepts') or [""],
|
| 261 |
"timestamp": current_timestamp,
|
| 262 |
+
"username": user_info['email'] if user_info['success'] else "",
|
| 263 |
"password": values_dic['password'] or "",
|
| 264 |
"id": values_dic['id'],
|
| 265 |
"excluded": False if values_dic.get('excluded') is None else bool(values_dic.get('excluded')),
|
logic/handlers.py
CHANGED
|
@@ -4,6 +4,7 @@ import io
|
|
| 4 |
import PIL
|
| 5 |
import requests
|
| 6 |
from typing import Literal
|
|
|
|
| 7 |
|
| 8 |
from datasets import load_dataset, concatenate_datasets, Image
|
| 9 |
from data.lang2eng_map import lang2eng_mapping
|
|
@@ -12,6 +13,7 @@ import gradio as gr
|
|
| 12 |
import bcrypt
|
| 13 |
from config.settings import HF_API_TOKEN
|
| 14 |
from huggingface_hub import snapshot_download
|
|
|
|
| 15 |
# from .blur import blur_faces, detect_faces
|
| 16 |
from retinaface import RetinaFace
|
| 17 |
from gradio_modal import Modal
|
|
@@ -71,13 +73,13 @@ def clear_data(message: Literal["submit", "remove"] | None = None):
|
|
| 71 |
gr.Info("If you logged in, you will soon see it at the bottom of the page, where you can edit it or delete it", title="Thank you for submitting your data! π", duration=5)
|
| 72 |
elif message == "remove":
|
| 73 |
gr.Info("", title="Your data has been deleted! ποΈ", duration=5)
|
| 74 |
-
return (None, None, None, None, None, gr.update(value=None),
|
| 75 |
gr.update(value=[]), gr.update(value=[]), gr.update(value=[]),
|
| 76 |
gr.update(value=[]), gr.update(value=[]))
|
| 77 |
|
| 78 |
|
| 79 |
def exit():
|
| 80 |
-
return (None, None, None, gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
|
| 81 |
gr.update(value=None), gr.update(value=None), [None, None, "", ""], gr.update(value=None),
|
| 82 |
gr.update(value=None), gr.update(value=None),
|
| 83 |
gr.update(value=None), gr.update(value=None), gr.update(value=None),
|
|
@@ -87,9 +89,8 @@ def exit():
|
|
| 87 |
def validate_metadata(country, language):
|
| 88 |
# Perform your validation logic here
|
| 89 |
if country is None or language is None:
|
| 90 |
-
return gr.
|
| 91 |
-
|
| 92 |
-
return gr.Button("Proceed", interactive=True)
|
| 93 |
|
| 94 |
|
| 95 |
def validate_inputs(image, ori_img, concept): # is_blurred
|
|
@@ -129,6 +130,30 @@ def validate_inputs(image, ori_img, concept): # is_blurred
|
|
| 129 |
|
| 130 |
return gr.Button("Submit", variant="primary", interactive=True), result_image, ori_img # is_blurred
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
def count_words(caption, language):
|
| 134 |
match language:
|
|
@@ -152,8 +177,14 @@ def add_prefix(example, column_name, prefix):
|
|
| 152 |
example[column_name] = (f"{prefix}/" + example[column_name])
|
| 153 |
return example
|
| 154 |
|
| 155 |
-
def update_user_data(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
|
|
|
| 157 |
datasets_list = []
|
| 158 |
# Try loading local dataset
|
| 159 |
try:
|
|
@@ -191,18 +222,19 @@ def update_user_data(username, password, country, language_choice, HF_DATASET_NA
|
|
| 191 |
# Handle all empty
|
| 192 |
if not datasets_list:
|
| 193 |
if username: # User is logged in but has no data
|
| 194 |
-
return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>")
|
| 195 |
else: # No user logged in
|
| 196 |
-
return gr.Dataset(samples=[]), gr.Markdown("")
|
| 197 |
|
| 198 |
dataset = concatenate_datasets(datasets_list)
|
| 199 |
# TODO: we should link username with password and language and country, otherwise there will be an error when loading with different language and clicking on the example
|
| 200 |
-
if username
|
| 201 |
-
user_dataset = dataset.filter(lambda x: x['username'] == username
|
| 202 |
user_dataset = user_dataset.sort('timestamp', reverse=True)
|
| 203 |
# Show only unique entries (most recent)
|
| 204 |
user_ids = set()
|
| 205 |
samples = []
|
|
|
|
| 206 |
for d in user_dataset:
|
| 207 |
if d['id'] in user_ids:
|
| 208 |
continue
|
|
@@ -229,6 +261,10 @@ def update_user_data(username, password, country, language_choice, HF_DATASET_NA
|
|
| 229 |
d['image_file'], d['image_url'], d['caption'] or "", d['country'],
|
| 230 |
d['language'], d['category'], d['concept'], additional_concepts_by_category, d['id']] # d['is_blurred']
|
| 231 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
# return gr.Dataset(samples=samples), None
|
| 233 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 234 |
# Clean up the βAdditional Conceptsβ column (index 7)
|
|
@@ -255,10 +291,14 @@ def update_user_data(username, password, country, language_choice, HF_DATASET_NA
|
|
| 255 |
row_copy[7] = ", ".join(vals)
|
| 256 |
cleaned.append(row_copy)
|
| 257 |
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
else:
|
| 260 |
# TODO: should we show the entire dataset instead? What about "other data" tab?
|
| 261 |
-
return gr.Dataset(samples=[]), None
|
| 262 |
|
| 263 |
|
| 264 |
def update_language(local_storage, metadata_dict, concepts_dict):
|
|
@@ -357,7 +397,7 @@ def update_intro_language(selected_country, selected_language, intro_markdown, m
|
|
| 357 |
return gr.Markdown(INTRO_TEXT)
|
| 358 |
|
| 359 |
|
| 360 |
-
def handle_click_example(user_examples, concepts_dict):
|
| 361 |
# print("handle_click_example")
|
| 362 |
# print(user_examples)
|
| 363 |
# ex = [item for item in user_examples]
|
|
@@ -365,7 +405,6 @@ def handle_click_example(user_examples, concepts_dict):
|
|
| 365 |
# 1) Turn the flat string in slot 7 back into a list-of-lists
|
| 366 |
ex = list(user_examples)
|
| 367 |
raw_ac = ex[7] if len(ex) > 7 else ""
|
| 368 |
-
|
| 369 |
country_btn = ex[3]
|
| 370 |
language_btn = ex[4]
|
| 371 |
concepts = concepts_dict[country_btn][language_btn]
|
|
@@ -441,7 +480,13 @@ def handle_click_example(user_examples, concepts_dict):
|
|
| 441 |
# dropdown_values.append(None)
|
| 442 |
|
| 443 |
# Need to return values for each category dropdown
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
# return [
|
| 446 |
# image_inp,
|
| 447 |
# image_url_inp,
|
|
@@ -535,8 +580,8 @@ def blur_selected_faces(image, blur_faces_ids, faces_info, face_img, faces_count
|
|
| 535 |
parsed_faces_ids = [f"face_{val.split(':')[-1].strip()}" for val in parsed_faces_ids]
|
| 536 |
|
| 537 |
# Base blur amount and bounds
|
| 538 |
-
MIN_BLUR =
|
| 539 |
-
MAX_BLUR =
|
| 540 |
|
| 541 |
blurring_start = time.time()
|
| 542 |
# Process each face
|
|
@@ -688,4 +733,28 @@ def check_exclude_fn(image):
|
|
| 688 |
|
| 689 |
def has_user_json(username, country,language_choice, local_ds_directory_path):
|
| 690 |
"""Check if JSON files exist for username pattern."""
|
| 691 |
-
return bool(glob.glob(os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"), recursive=True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import PIL
|
| 5 |
import requests
|
| 6 |
from typing import Literal
|
| 7 |
+
from logic.supabase_client import auth_handler
|
| 8 |
|
| 9 |
from datasets import load_dataset, concatenate_datasets, Image
|
| 10 |
from data.lang2eng_map import lang2eng_mapping
|
|
|
|
| 13 |
import bcrypt
|
| 14 |
from config.settings import HF_API_TOKEN
|
| 15 |
from huggingface_hub import snapshot_download
|
| 16 |
+
from logic.vlm import vlm_manager
|
| 17 |
# from .blur import blur_faces, detect_faces
|
| 18 |
from retinaface import RetinaFace
|
| 19 |
from gradio_modal import Modal
|
|
|
|
| 73 |
gr.Info("If you logged in, you will soon see it at the bottom of the page, where you can edit it or delete it", title="Thank you for submitting your data! π", duration=5)
|
| 74 |
elif message == "remove":
|
| 75 |
gr.Info("", title="Your data has been deleted! ποΈ", duration=5)
|
| 76 |
+
return (None, None, None, gr.update(value=None), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True), None, None, gr.update(value=None),
|
| 77 |
gr.update(value=[]), gr.update(value=[]), gr.update(value=[]),
|
| 78 |
gr.update(value=[]), gr.update(value=[]))
|
| 79 |
|
| 80 |
|
| 81 |
def exit():
|
| 82 |
+
return (None, None, None, gr.update(value=None), gr.update(value=None, visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True), gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
|
| 83 |
gr.update(value=None), gr.update(value=None), [None, None, "", ""], gr.update(value=None),
|
| 84 |
gr.update(value=None), gr.update(value=None),
|
| 85 |
gr.update(value=None), gr.update(value=None), gr.update(value=None),
|
|
|
|
| 89 |
def validate_metadata(country, language):
|
| 90 |
# Perform your validation logic here
|
| 91 |
if country is None or language is None:
|
| 92 |
+
return gr.update(interactive=False)
|
| 93 |
+
return gr.update(interactive=True)
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
def validate_inputs(image, ori_img, concept): # is_blurred
|
|
|
|
| 130 |
|
| 131 |
return gr.Button("Submit", variant="primary", interactive=True), result_image, ori_img # is_blurred
|
| 132 |
|
| 133 |
+
def generate_vlm_caption(image, model_name="SmolVLM-500M"): # processor, model
|
| 134 |
+
"""
|
| 135 |
+
Generate a caption for the given image using a Vision-Language Model.
|
| 136 |
+
Uses the global VLMManager for efficient model loading and caching.
|
| 137 |
+
"""
|
| 138 |
+
if image is None:
|
| 139 |
+
gr.Warning("β οΈ Please upload an image first.", duration=5)
|
| 140 |
+
return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
# Use the global VLMManager to load/get the model
|
| 144 |
+
vlm_manager.load_model(model_name)
|
| 145 |
+
caption = vlm_manager.generate_caption(image)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error generating caption: {e}. Cleaning up memory and try again.")
|
| 148 |
+
gr.Warning(f"β οΈ Error generating caption: {e} due to memory issues. Please try again.", duration=5)
|
| 149 |
+
# vlm_manager.cleanup_memory()
|
| 150 |
+
return None, gr.update(visible=False), gr.update(visible=False), gr.update(interactive=True), gr.update(interactive=True)
|
| 151 |
+
finally: # For now, let's cleanup memory after each generation
|
| 152 |
+
vlm_manager.cleanup_memory()
|
| 153 |
+
|
| 154 |
+
# print(caption)
|
| 155 |
+
|
| 156 |
+
return caption, gr.update(visible=True), gr.update(visible=True), gr.update(interactive=False), gr.update(interactive=False)
|
| 157 |
|
| 158 |
def count_words(caption, language):
|
| 159 |
match language:
|
|
|
|
| 177 |
example[column_name] = (f"{prefix}/" + example[column_name])
|
| 178 |
return example
|
| 179 |
|
| 180 |
+
def update_user_data(client , country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
|
| 181 |
+
user_info = auth_handler.is_logged_in(client)
|
| 182 |
+
print(f"User info: {user_info}")
|
| 183 |
+
if not user_info['success']:
|
| 184 |
+
print("User is not logged in or session expired.")
|
| 185 |
+
return gr.Dataset(samples=[]), None, None
|
| 186 |
|
| 187 |
+
username = user_info['email']
|
| 188 |
datasets_list = []
|
| 189 |
# Try loading local dataset
|
| 190 |
try:
|
|
|
|
| 222 |
# Handle all empty
|
| 223 |
if not datasets_list:
|
| 224 |
if username: # User is logged in but has no data
|
| 225 |
+
return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>"), None
|
| 226 |
else: # No user logged in
|
| 227 |
+
return gr.Dataset(samples=[]), gr.Markdown(""), None
|
| 228 |
|
| 229 |
dataset = concatenate_datasets(datasets_list)
|
| 230 |
# TODO: we should link username with password and language and country, otherwise there will be an error when loading with different language and clicking on the example
|
| 231 |
+
if username:
|
| 232 |
+
user_dataset = dataset.filter(lambda x: x['username'] == username)
|
| 233 |
user_dataset = user_dataset.sort('timestamp', reverse=True)
|
| 234 |
# Show only unique entries (most recent)
|
| 235 |
user_ids = set()
|
| 236 |
samples = []
|
| 237 |
+
vlm_captions = dict()
|
| 238 |
for d in user_dataset:
|
| 239 |
if d['id'] in user_ids:
|
| 240 |
continue
|
|
|
|
| 261 |
d['image_file'], d['image_url'], d['caption'] or "", d['country'],
|
| 262 |
d['language'], d['category'], d['concept'], additional_concepts_by_category, d['id']] # d['is_blurred']
|
| 263 |
)
|
| 264 |
+
|
| 265 |
+
if 'vlm_caption' in d:
|
| 266 |
+
vlm_captions[d['id']] = d.get('vlm_caption', "")
|
| 267 |
+
|
| 268 |
# return gr.Dataset(samples=samples), None
|
| 269 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 270 |
# Clean up the βAdditional Conceptsβ column (index 7)
|
|
|
|
| 291 |
row_copy[7] = ", ".join(vals)
|
| 292 |
cleaned.append(row_copy)
|
| 293 |
|
| 294 |
+
# check if vlm_captions is an empty dictionary
|
| 295 |
+
if not vlm_captions:
|
| 296 |
+
vlm_captions = None
|
| 297 |
+
|
| 298 |
+
return gr.Dataset(samples=cleaned), None, vlm_captions
|
| 299 |
else:
|
| 300 |
# TODO: should we show the entire dataset instead? What about "other data" tab?
|
| 301 |
+
return gr.Dataset(samples=[]), None, None
|
| 302 |
|
| 303 |
|
| 304 |
def update_language(local_storage, metadata_dict, concepts_dict):
|
|
|
|
| 397 |
return gr.Markdown(INTRO_TEXT)
|
| 398 |
|
| 399 |
|
| 400 |
+
def handle_click_example(user_examples, vlm_captions, concepts_dict):
|
| 401 |
# print("handle_click_example")
|
| 402 |
# print(user_examples)
|
| 403 |
# ex = [item for item in user_examples]
|
|
|
|
| 405 |
# 1) Turn the flat string in slot 7 back into a list-of-lists
|
| 406 |
ex = list(user_examples)
|
| 407 |
raw_ac = ex[7] if len(ex) > 7 else ""
|
|
|
|
| 408 |
country_btn = ex[3]
|
| 409 |
language_btn = ex[4]
|
| 410 |
concepts = concepts_dict[country_btn][language_btn]
|
|
|
|
| 480 |
# dropdown_values.append(None)
|
| 481 |
|
| 482 |
# Need to return values for each category dropdown
|
| 483 |
+
|
| 484 |
+
vlm_caption = None
|
| 485 |
+
if vlm_captions:
|
| 486 |
+
if exampleid_btn in vlm_captions:
|
| 487 |
+
vlm_caption = vlm_captions[exampleid_btn]
|
| 488 |
+
|
| 489 |
+
return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts_by_category + [True] + [vlm_caption] # loading_example flag + vlm_caption
|
| 490 |
# return [
|
| 491 |
# image_inp,
|
| 492 |
# image_url_inp,
|
|
|
|
| 580 |
parsed_faces_ids = [f"face_{val.split(':')[-1].strip()}" for val in parsed_faces_ids]
|
| 581 |
|
| 582 |
# Base blur amount and bounds
|
| 583 |
+
MIN_BLUR = 131 # Minimum blur amount (must be odd)
|
| 584 |
+
MAX_BLUR = 351 # Maximum blur amount (must be odd)
|
| 585 |
|
| 586 |
blurring_start = time.time()
|
| 587 |
# Process each face
|
|
|
|
| 733 |
|
| 734 |
def has_user_json(username, country,language_choice, local_ds_directory_path):
|
| 735 |
"""Check if JSON files exist for username pattern."""
|
| 736 |
+
return bool(glob.glob(os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"), recursive=True))
|
| 737 |
+
|
| 738 |
+
def submit_button_clicked(vlm_output):
|
| 739 |
+
|
| 740 |
+
if vlm_output is None or vlm_output == '':
|
| 741 |
+
return Modal(visible=True), Modal(visible=False)
|
| 742 |
+
else:
|
| 743 |
+
return Modal(visible=False), Modal(visible=True)
|
| 744 |
+
# def submit_button_clicked(vlm_output, save_fn, data_outputs):
|
| 745 |
+
# if vlm_output is None:
|
| 746 |
+
# return Modal(visible=True)
|
| 747 |
+
# else:
|
| 748 |
+
# try:
|
| 749 |
+
# save_fn(list(data_outputs.values()))
|
| 750 |
+
# except Exception as e:
|
| 751 |
+
# gr.Error(f"β οΈ Error saving data: {e}")
|
| 752 |
+
|
| 753 |
+
# try:
|
| 754 |
+
# image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, exampleid_btn, category_btn, concept_btn, \
|
| 755 |
+
# category_concept_dropdowns0, category_concept_dropdowns1, category_concept_dropdowns2, category_concept_dropdowns3, \
|
| 756 |
+
# category_concept_dropdowns4 = clear_data("submit")
|
| 757 |
+
# except Exception as e:
|
| 758 |
+
# gr.Error(f"β οΈ Error clearing data: {e}")
|
| 759 |
+
|
| 760 |
+
# return Modal(visible=False)
|
logic/supabase_client.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from supabase import create_client, Client
|
| 3 |
+
import os
|
| 4 |
+
from config.settings import SUPABASE_URL, SUPABASE_KEY, REDIRECT_TO_URL
|
| 5 |
+
import traceback
|
| 6 |
+
from supabase.lib.client_options import ClientOptions
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# --- Supabase Authentication Class ---
|
| 10 |
+
|
| 11 |
+
class SupabaseAuth:
|
| 12 |
+
"""A class to handle Supabase authentication logic."""
|
| 13 |
+
def __init__(self, url: str, key: str):
|
| 14 |
+
self.url = url
|
| 15 |
+
self.key = key
|
| 16 |
+
try:
|
| 17 |
+
self.client: Client = create_client(url, key)
|
| 18 |
+
except Exception as e:
|
| 19 |
+
print(f"Error creating Supabase client: {e}")
|
| 20 |
+
self.client = None
|
| 21 |
+
|
| 22 |
+
def login(self, email: str, password: str):
|
| 23 |
+
"""
|
| 24 |
+
Attempts to log in a user and returns a user-specific client.
|
| 25 |
+
"""
|
| 26 |
+
if not self.client:
|
| 27 |
+
return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
|
| 28 |
+
try:
|
| 29 |
+
response = self.client.auth.sign_in_with_password({"email": email, "password": password})
|
| 30 |
+
user_session = response.session
|
| 31 |
+
|
| 32 |
+
# Create a new, authenticated client for this user
|
| 33 |
+
authenticated_client = create_client(
|
| 34 |
+
self.url,
|
| 35 |
+
self.key,
|
| 36 |
+
# options={"headers": {"Authorization": f"Bearer {user_session.access_token}"}}
|
| 37 |
+
options=ClientOptions(
|
| 38 |
+
headers={"Authorization": f"Bearer {user_session.access_token}"},
|
| 39 |
+
)
|
| 40 |
+
)
|
| 41 |
+
authenticated_client.auth.set_session(user_session.access_token, user_session.refresh_token)
|
| 42 |
+
|
| 43 |
+
session_data = {
|
| 44 |
+
"refresh_token": user_session.refresh_token,
|
| 45 |
+
"user_email": user_session.user.email,
|
| 46 |
+
"client": authenticated_client
|
| 47 |
+
}
|
| 48 |
+
return {'success': True, 'data': session_data, 'message': f"Welcome, {user_session.user.email}!"}
|
| 49 |
+
except Exception as e:
|
| 50 |
+
# print(f"Error logging in: {e}")
|
| 51 |
+
# traceback.print_exc()
|
| 52 |
+
# Handle specific error messages for better user feedback
|
| 53 |
+
return {'success': False, 'data': None, 'message': f"Login failed: {e}"}
|
| 54 |
+
|
| 55 |
+
def sign_up(self, email: str, password: str):
|
| 56 |
+
"""Signs up a new user."""
|
| 57 |
+
if not self.client:
|
| 58 |
+
return {'success': False, 'message': "Supabase client not initialized."}
|
| 59 |
+
try:
|
| 60 |
+
# Supabase sign_up returns a session if email confirmation is disabled,
|
| 61 |
+
# or just a user object if it's enabled. We'll just return a success message.
|
| 62 |
+
self.client.auth.sign_up({
|
| 63 |
+
"email": email,
|
| 64 |
+
"password": password,
|
| 65 |
+
})
|
| 66 |
+
return {'success': True, 'message': 'Sign up successful! You can login now.'}
|
| 67 |
+
except Exception as e:
|
| 68 |
+
return {'success': False, 'message': f"Sign up failed: {e}"}
|
| 69 |
+
|
| 70 |
+
def restore_session(self, refresh_token: str):
|
| 71 |
+
"""
|
| 72 |
+
Attempts to restore a session using a refresh token.
|
| 73 |
+
"""
|
| 74 |
+
if not self.client:
|
| 75 |
+
return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
|
| 76 |
+
try:
|
| 77 |
+
response = self.client.auth.refresh_session(refresh_token)
|
| 78 |
+
user_session = response.session
|
| 79 |
+
|
| 80 |
+
authenticated_client = create_client(
|
| 81 |
+
self.url,
|
| 82 |
+
self.key,
|
| 83 |
+
options=ClientOptions(
|
| 84 |
+
headers={"Authorization": f"Bearer {user_session.access_token}"},
|
| 85 |
+
)
|
| 86 |
+
)
|
| 87 |
+
authenticated_client.auth.set_session(user_session.access_token, user_session.refresh_token)
|
| 88 |
+
|
| 89 |
+
session_data = {
|
| 90 |
+
"refresh_token": user_session.refresh_token,
|
| 91 |
+
"user_email": user_session.user.email,
|
| 92 |
+
"client": authenticated_client
|
| 93 |
+
}
|
| 94 |
+
print("Session restored successfully:", session_data)
|
| 95 |
+
return {'success': True, 'data': session_data, 'message': f"Welcome, {user_session.user.email}!"}
|
| 96 |
+
except Exception as e:
|
| 97 |
+
print("failed to restore session:", e)
|
| 98 |
+
return {'success': False, 'data': None, 'message': f"Failed to restore session: {e}"}
|
| 99 |
+
|
| 100 |
+
def logout(self, user_client: Client):
|
| 101 |
+
"""Signs out the user from Supabase, invalidating the token."""
|
| 102 |
+
if not user_client:
|
| 103 |
+
return {'success': False, 'message': 'No user client provided to log out.'}
|
| 104 |
+
try:
|
| 105 |
+
user_client.auth.sign_out()
|
| 106 |
+
return {'success': True, 'message': 'Successfully signed out from Supabase.'}
|
| 107 |
+
except Exception as e:
|
| 108 |
+
# It's often safe to ignore errors here (e.g., if token already expired)
|
| 109 |
+
# but we'll log it for debugging.
|
| 110 |
+
print(f"Error signing out from Supabase: {e}")
|
| 111 |
+
return {'success': False, 'message': f'Error signing out: {e}'}
|
| 112 |
+
|
| 113 |
+
def change_password(self, user_client: Client, new_password: str):
|
| 114 |
+
"""Changes the user's password."""
|
| 115 |
+
if not user_client:
|
| 116 |
+
return {'success': False, 'message': 'No user client provided to change password.'}
|
| 117 |
+
try:
|
| 118 |
+
user_client.auth.update_user({"password": new_password})
|
| 119 |
+
return {'success': True, 'message': 'Password changed successfully.'}
|
| 120 |
+
except Exception as e:
|
| 121 |
+
return {'success': False, 'message': f'Error changing password: {e}'}
|
| 122 |
+
|
| 123 |
+
def is_logged_in(self, user_client: Client):
|
| 124 |
+
"""Checks if a user is currently authenticated and returns their email."""
|
| 125 |
+
print("Checking if user is logged in...", user_client)
|
| 126 |
+
if not user_client:
|
| 127 |
+
return {'success': False, 'email': None, 'message': 'No user client provided.'}
|
| 128 |
+
try:
|
| 129 |
+
user_response = user_client.auth.get_user()
|
| 130 |
+
user = user_response.user
|
| 131 |
+
if user:
|
| 132 |
+
return {'success': True, 'email': user.email, 'message': f'Logged in as: {user.email}'}
|
| 133 |
+
else:
|
| 134 |
+
return {'success': False, 'email': None, 'message': 'User is not logged in.'}
|
| 135 |
+
except Exception as e:
|
| 136 |
+
# This might happen if the token has expired and can't be refreshed.
|
| 137 |
+
return {'success': False, 'email': None, 'message': f'Authentication check failed: {e}'}
|
| 138 |
+
|
| 139 |
+
def reset_password_for_email(self, email: str):
|
| 140 |
+
"""
|
| 141 |
+
Sends a password reset email to the specified address.
|
| 142 |
+
"""
|
| 143 |
+
if not self.client:
|
| 144 |
+
return {'success': False, 'message': "Supabase client not initialized."}
|
| 145 |
+
try:
|
| 146 |
+
self.client.auth.reset_password_for_email(
|
| 147 |
+
email,
|
| 148 |
+
{
|
| 149 |
+
"redirect_to": str(REDIRECT_TO_URL),
|
| 150 |
+
}
|
| 151 |
+
)
|
| 152 |
+
return {'success': True, 'message': "Password reset email sent. Check your inbox!"}
|
| 153 |
+
except Exception as e:
|
| 154 |
+
return {'success': False, 'message': f"Failed to send reset email: {e}"}
|
| 155 |
+
|
| 156 |
+
    def retrieve_session_from_tokens(self, access_token: str, refresh_token: str):
        """
        Retrieves a session from an access token and refresh token.
        This is typically used after a password recovery link is clicked.

        Args:
            access_token: JWT access token extracted from the recovery URL.
            refresh_token: Refresh token extracted from the recovery URL.

        Returns:
            dict: ``{'success': bool, 'data': dict | None, 'message': str}``.
            On success, ``data`` holds the refresh token, the user's email and
            a dedicated authenticated client for this user.
        """
        if not self.client:
            return {'success': False, 'data': None, 'message': "Supabase client not initialized."}
        try:
            # Set the session on the main client to verify tokens and get user info
            self.client.auth.set_session(access_token, refresh_token)
            user_response = self.client.auth.get_user()
            user = user_response.user

            if not user:
                return {'success': False, 'data': None, 'message': "Could not retrieve user from tokens."}

            # Create a new, authenticated client for this user, similar to login.
            # NOTE(review): the shared self.client now also carries this session
            # until it is replaced — presumably acceptable for this app; verify.
            authenticated_client = create_client(
                self.url,
                self.key,
                options=ClientOptions(
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            )
            authenticated_client.auth.set_session(access_token, refresh_token)

            session_data = {
                "refresh_token": refresh_token,
                "user_email": user.email,
                "client": authenticated_client
            }
            return {'success': True, 'data': session_data, 'message': f"Welcome, {user.email}!"}
        except Exception as e:
            return {'success': False, 'data': None, 'message': f"Failed to retrieve session from tokens: {e}"}
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Module-level singleton. NOTE: constructed at import time with the
# credentials from config.settings, so a missing/invalid SUPABASE_URL or
# SUPABASE_KEY surfaces when this module is first imported.
auth_handler = SupabaseAuth(SUPABASE_URL, SUPABASE_KEY)
|
| 193 |
+
|
logic/vlm.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, torchvision.transforms as T
|
| 2 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import TorchAoConfig, Qwen2_5_VLForConditionalGeneration, Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, AutoModel
|
| 5 |
+
from qwen_vl_utils import process_vision_info
|
| 6 |
+
import gc
|
| 7 |
+
# from transformers.image_utils import load_image
|
| 8 |
+
|
| 9 |
+
# Standard ImageNet channel statistics, used to normalize image tiles for
# the InternVL preprocessing pipeline (see VLMManager._build_transform).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
| 12 |
+
class VLMManager:
    """
    A manager class for Vision-Language Models that handles model loading,
    caching, and dynamic switching between different models.

    Exactly one model is resident at a time; switching models releases the
    previous one (including GPU memory) before loading the next. InternVL
    uses a tokenizer for its text side, all other models use a processor.
    """

    def __init__(self, default_model: str = "Gemma3-4B"):
        """
        Initialize the VLM Manager and eagerly load the default model.

        Args:
            default_model (str): The default model to load initially.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.current_model_name = None
        self.processor = None
        self.tokenizer = None  # only used by InternVL3_5-8B
        self.model = None

        self.system_message = """
        You are an expert cultural-aware image-analysis assistant. For every image:
        1. Output exactly 40 words in total.
        2. Use a single paragraph (no lists or bullet points).
        3. Describe Who (appearance/emotion), What (action), and Where (setting).
        4. Do NOT include opinions or speculations.
        5. If you go over 40 words, shorten or remove non-essential details.
        """

        self.user_prompt = """
        Given this image, please provide an image description of around 40 words with extensive and detailed visual information.

        Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations.

        The text needs to include the main concept and describe the content of the image in detail by including:
        - Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
        - What?: The actions performed in the image.
        - Where?: The setting of the image, including the size, color, and relationships between objects.
        """

        # Load the default model (may download weights on first use).
        self.load_model(default_model)

    def load_model(self, model_name: str):
        """
        Load a VLM model. If the model is already loaded, reuse the cached one.

        Args:
            model_name (str): One of "SmolVLM-500M", "Qwen2.5-VL-7B",
                "InternVL3_5-8B", "Gemma3-4B".

        Returns:
            tuple: (tokenizer, model) for "InternVL3_5-8B", otherwise
            (processor, model). The original implementation only returned on
            the cache-hit path; returning consistently is backward compatible
            (the previous fresh-load return value was None and unused).

        Raises:
            ValueError: If model_name is not supported.
        """
        if self.current_model_name == model_name and self.model is not None:
            print(f"Model {model_name} is already loaded, using cached version.")
            if self.current_model_name == "InternVL3_5-8B":
                return self.tokenizer, self.model
            return self.processor, self.model

        print(f"Loading model: {model_name}")
        self._release_current_model()

        if model_name == "SmolVLM-500M":
            self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct")
        elif model_name == "Qwen2.5-VL-7B":
            self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct")
        elif model_name == "InternVL3_5-8B":
            self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct")
        elif model_name == "Gemma3-4B":
            self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it")
        else:
            raise ValueError(f"Model {model_name} is not supported or not available.")

        self.current_model_name = model_name
        print(f"Successfully loaded model: {model_name}")
        if model_name == "InternVL3_5-8B":
            return self.tokenizer, self.model
        return self.processor, self.model

    def _release_current_model(self):
        """Drop references to the current model and its processor/tokenizer,
        then reclaim host and GPU memory. No-op when nothing is loaded."""
        if self.model is not None:
            del self.model
            self.model = None
            if self.current_model_name == "InternVL3_5-8B":
                if self.tokenizer is not None:
                    del self.tokenizer
                    self.tokenizer = None
            else:
                if self.processor is not None:
                    del self.processor
                    self.processor = None
            # Force garbage collection and clear the CUDA cache so the next
            # model load actually has the memory available.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()  # wait for pending kernels

    def generate_caption(self, image):
        """
        Generate a ~40-word caption for the given image using the loaded model.

        Args:
            image: A numpy array (H, W, 3) or a PIL.Image. All model paths
                accept either form (normalized via _to_pil).

        Returns:
            str: The generated caption.

        Raises:
            ValueError: If no supported model is currently loaded.
        """
        dispatch = {
            "SmolVLM-500M": self._inference_smolvlm_model,
            "Qwen2.5-VL-7B": self._inference_qwen25_model,
            "InternVL3_5-8B": self._inference_internvl35_model,
            "Gemma3-4B": self._inference_gemma3_model,
        }
        inference = dispatch.get(self.current_model_name)
        if inference is None:
            raise ValueError(f"Model {self.current_model_name} is not supported or not available.")
        return inference(image)

    def get_current_model(self):
        """
        Get the currently loaded model and processor.

        Returns:
            tuple: (processor, model, model_name). For InternVL the processor
            slot is None (its tokenizer lives in self.tokenizer).
        """
        return self.processor, self.model, self.current_model_name

    def cleanup_memory(self):
        """
        Explicit memory cleanup method that can be called to free GPU memory.
        After this call no model is loaded until load_model() is invoked again.
        """
        if self.model is not None:
            del self.model
            self.model = None
        if hasattr(self, 'processor') and self.processor is not None:
            del self.processor
            self.processor = None
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        self.current_model_name = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print("Memory cleanup completed.")

    #########################################################
    ## Load functions

    def _load_smolvlm_model(self, model_name):
        """Load SmolVLM model; returns (processor, model)."""
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(
            model_name,
            _attn_implementation="eager"
        ).to(self.device)
        model.eval()
        return processor, model

    def _load_qwen25_model(self, model_name):
        """Load Qwen2.5-VL model; returns (processor, model)."""
        # flash_attention_2 would reduce memory / speed things up, but is
        # deliberately not enabled here (CUDA/toolchain compatibility).
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model.eval()
        return processor, model

    def _load_internvl35_model(self, model_name):
        """Load InternVL3.5 model; returns (tokenizer, model).

        InternVL drives its text side with a tokenizer rather than a
        processor, hence the different return shape.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
            low_cpu_mem_usage=True,
            use_flash_attn=False,  # keep False to avoid CUDA mismatch issues
            trust_remote_code=True,
            device_map="auto"
        )
        model.eval()
        return tokenizer, model

    def _load_gemma3_model(self, model_name):
        """Load Gemma3 model (int4 weight-only quantized); returns (processor, model)."""
        quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
        model = Gemma3ForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=quantization_config
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model.eval()
        return processor, model

    #########################################################
    ## Inference functions

    def check_processor_and_model(self):
        """Raise if no processor/model pair is currently loaded."""
        if self.processor is None or self.model is None:
            raise ValueError("Processor and model must be loaded before generating a caption.")

    @staticmethod
    def _to_pil(image):
        """Coerce a numpy array (H, W, 3) or a PIL.Image to a PIL image.

        Accepting both forms lets every inference path take the same inputs
        (previously only the InternVL path tolerated PIL images).
        """
        if hasattr(image, "shape"):  # numpy array
            return Image.fromarray(image.astype("uint8"), mode="RGB")
        return image

    def _inference_qwen25_model(self, image):
        """Inference Qwen2.5-VL model; returns the caption string."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": self._to_pil(image)},
                    {"type": "text", "text": self.user_prompt},
                ],
            }
        ]

        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        caption = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Clean up tensors to free GPU memory
        del inputs, generated_ids, generated_ids_trimmed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return caption

    def _inference_gemma3_model(self, image):
        """Inference Gemma3 model; returns the caption string."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": self._to_pil(image)},
                    {"type": "text", "text": self.user_prompt}
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]  # drop the prompt tokens

        caption = self.processor.decode(generation, skip_special_tokens=True)

        # Clean up tensors to free GPU memory
        del inputs, generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return caption

    def _inference_smolvlm_model(self, image):
        """Inference SmolVLM model; returns the caption string."""
        self.check_processor_and_model()
        messages = [
            {
                "role": "system",
                "content": self.system_message
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.user_prompt}
                ]
            }
        ]

        # Prepare inputs
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=prompt, images=[self._to_pil(image)], return_tensors="pt")
        inputs = inputs.to(self.model.device)

        # 200 new tokens is plenty for a ~40-word caption.
        generated_ids = self.model.generate(**inputs, max_new_tokens=200)
        generated_texts = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]

        # SmolVLM decodes the full conversation; keep only the assistant turn.
        if "Assistant:" in generated_texts:
            caption = generated_texts.split("Assistant:", 1)[1].strip()
        else:
            caption = generated_texts.strip()

        # Clean up tensors to free GPU memory
        del inputs, generated_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return caption

    def _inference_internvl35_model(self, image):
        """Inference InternVL3.5 model via its chat() API; returns the caption."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")
        pil_image = self._to_pil(image)

        pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
        pixel_values = pixel_values.to(dtype=torch.bfloat16, device=self.model.device)

        # Format question with image token (matches official docs)
        question = "<image>\n" + self.user_prompt

        # Greedy decoding: temperature is intentionally omitted because it is
        # ignored (and triggers warnings) when do_sample=False.
        gen_cfg = dict(
            max_new_tokens=128,
            do_sample=False,
        )

        # Use model's chat method (official approach)
        response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)

        # Clean up tensors to free GPU memory
        del pixel_values
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return response.strip()

    def _image_to_pixel_values(self, img, size=448, max_num=12):
        """Tile *img* and stack the normalized tiles into one tensor
        of shape (num_tiles, 3, size, size)."""
        transform = self._build_transform(size)
        tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
        pixel_values = torch.stack([transform(t) for t in tiles])
        return pixel_values

    def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
        """Split *image* into square tiles following the InternVL model card:
        pick the tile grid (i x j, min_num <= i*j <= max_num) whose aspect
        ratio best matches the image, resize, crop tiles, and append a
        thumbnail of the whole image when more than one tile is produced."""
        w, h = image.size
        aspect = w / h
        targets = sorted({(i, j) for n in range(min_num, max_num+1)
                          for i in range(1, n+1) for j in range(1, n+1)
                          if i*j <= max_num and i*j >= min_num},
                         key=lambda x: x[0]*x[1])

        # pick closest ratio
        best = min(targets, key=lambda r: abs(aspect - r[0]/r[1]))
        tw, th = image_size * best[0], image_size * best[1]
        resized = image.resize((tw, th))

        tiles = []
        for i in range(best[0] * best[1]):
            box = ((i % (tw // image_size)) * image_size,
                   (i // (tw // image_size)) * image_size,
                   ((i % (tw // image_size)) + 1) * image_size,
                   ((i // (tw // image_size)) + 1) * image_size)
            tiles.append(resized.crop(box))

        if use_thumbnail and len(tiles) != 1:
            tiles.append(image.resize((image_size, image_size)))
        return tiles

    def _build_transform(self, input_size=448):
        """Tile transform: RGB-convert, resize, tensorize, ImageNet-normalize."""
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# Global VLM Manager instance
# NOTE: constructed at import time, so importing this module eagerly loads
# the default model (weight download + GPU memory) as a side effect.
vlm_manager = VLMManager()
|
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
absl-py==2.2.2
|
|
|
|
| 2 |
aiofiles==23.2.1
|
| 3 |
aiohappyeyeballs==2.6.1
|
| 4 |
aiohttp==3.11.16
|
|
@@ -8,15 +9,25 @@ anyio==4.9.0
|
|
| 8 |
astunparse==1.6.3
|
| 9 |
async-timeout==5.0.1
|
| 10 |
attrs==25.3.0
|
|
|
|
| 11 |
bcrypt==4.3.0
|
| 12 |
beautifulsoup4==4.13.3
|
|
|
|
|
|
|
|
|
|
| 13 |
certifi==2025.1.31
|
| 14 |
charset-normalizer==3.4.1
|
| 15 |
click==8.1.8
|
|
|
|
|
|
|
| 16 |
cycler==0.12.1
|
|
|
|
| 17 |
datasets==3.5.0
|
|
|
|
| 18 |
deep-translator==1.11.4
|
|
|
|
| 19 |
dill==0.3.8
|
|
|
|
| 20 |
et_xmlfile==2.0.0
|
| 21 |
exceptiongroup==1.2.2
|
| 22 |
fastapi==0.115.12
|
|
@@ -41,16 +52,36 @@ huggingface-hub==0.30.1
|
|
| 41 |
idna==3.10
|
| 42 |
Jinja2==3.1.6
|
| 43 |
keras==3.9.2
|
|
|
|
|
|
|
| 44 |
libclang==18.1.1
|
|
|
|
| 45 |
Markdown==3.8
|
| 46 |
markdown-it-py==3.0.0
|
| 47 |
MarkupSafe==3.0.2
|
| 48 |
mdurl==0.1.2
|
| 49 |
ml_dtypes==0.5.1
|
|
|
|
| 50 |
multidict==6.3.2
|
| 51 |
multiprocess==0.70.16
|
|
|
|
| 52 |
namex==0.0.8
|
|
|
|
| 53 |
numpy==2.1.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
opencv-python==4.11.0.86
|
| 55 |
openpyxl==3.1.5
|
| 56 |
opt_einsum==3.4.0
|
|
@@ -59,54 +90,78 @@ orjson==3.10.16
|
|
| 59 |
packaging==24.2
|
| 60 |
pandas==2.2.3
|
| 61 |
pillow==11.1.0
|
|
|
|
|
|
|
| 62 |
propcache==0.3.1
|
| 63 |
protobuf==5.29.4
|
|
|
|
| 64 |
pyarrow==19.0.1
|
| 65 |
-
pydantic
|
| 66 |
-
pydantic_core
|
| 67 |
pydub==0.25.1
|
| 68 |
Pygments==2.19.1
|
| 69 |
PySocks==1.7.1
|
|
|
|
| 70 |
python-dateutil==2.9.0.post0
|
| 71 |
python-dotenv==1.1.0
|
| 72 |
python-multipart==0.0.20
|
| 73 |
pytz==2025.2
|
|
|
|
| 74 |
PyYAML==6.0.2
|
|
|
|
|
|
|
| 75 |
requests==2.32.3
|
| 76 |
retina-face==0.0.17
|
| 77 |
rich==14.0.0
|
| 78 |
ruff==0.11.4
|
| 79 |
safehttpx==0.1.6
|
|
|
|
| 80 |
semantic-version==2.10.0
|
| 81 |
shellingham==1.5.4
|
| 82 |
six==1.17.0
|
|
|
|
| 83 |
sniffio==1.3.1
|
| 84 |
soupsieve==2.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
starlette==0.46.1
|
|
|
|
| 86 |
tensorboard==2.19.0
|
| 87 |
tensorboard-data-server==0.7.2
|
| 88 |
tensorflow==2.19.0
|
| 89 |
tensorflow-io-gcs-filesystem==0.37.1
|
| 90 |
termcolor==3.0.1
|
| 91 |
tf_keras==2.19.0
|
|
|
|
|
|
|
|
|
|
| 92 |
tomlkit==0.13.2
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
tqdm==4.67.1
|
|
|
|
|
|
|
| 94 |
typer==0.15.2
|
| 95 |
-
typing-inspection
|
| 96 |
-
typing_extensions
|
| 97 |
tzdata==2025.2
|
|
|
|
| 98 |
urllib3==2.3.0
|
| 99 |
uvicorn==0.34.0
|
|
|
|
|
|
|
| 100 |
websockets==15.0.1
|
| 101 |
Werkzeug==3.1.3
|
| 102 |
wrapt==1.17.2
|
| 103 |
xxhash==3.5.0
|
| 104 |
yarl==1.19.0
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
spacy-legacy==3.0.12
|
| 110 |
-
spacy-loggers==1.0.5
|
| 111 |
-
spacy_thai==0.7.8
|
| 112 |
-
spacy-udpipe==1.0.0
|
|
|
|
| 1 |
absl-py==2.2.2
|
| 2 |
+
accelerate==1.9.0
|
| 3 |
aiofiles==23.2.1
|
| 4 |
aiohappyeyeballs==2.6.1
|
| 5 |
aiohttp==3.11.16
|
|
|
|
| 9 |
astunparse==1.6.3
|
| 10 |
async-timeout==5.0.1
|
| 11 |
attrs==25.3.0
|
| 12 |
+
av==15.1.0
|
| 13 |
bcrypt==4.3.0
|
| 14 |
beautifulsoup4==4.13.3
|
| 15 |
+
bitsandbytes==0.46.1
|
| 16 |
+
blis==1.3.0
|
| 17 |
+
catalogue==2.0.10
|
| 18 |
certifi==2025.1.31
|
| 19 |
charset-normalizer==3.4.1
|
| 20 |
click==8.1.8
|
| 21 |
+
cloudpathlib==0.21.1
|
| 22 |
+
confection==0.1.5
|
| 23 |
cycler==0.12.1
|
| 24 |
+
cymem==2.0.11
|
| 25 |
datasets==3.5.0
|
| 26 |
+
decord==0.6.0
|
| 27 |
deep-translator==1.11.4
|
| 28 |
+
deplacy==2.1.0
|
| 29 |
dill==0.3.8
|
| 30 |
+
einops==0.8.1
|
| 31 |
et_xmlfile==2.0.0
|
| 32 |
exceptiongroup==1.2.2
|
| 33 |
fastapi==0.115.12
|
|
|
|
| 52 |
idna==3.10
|
| 53 |
Jinja2==3.1.6
|
| 54 |
keras==3.9.2
|
| 55 |
+
langcodes==3.5.0
|
| 56 |
+
language_data==1.3.0
|
| 57 |
libclang==18.1.1
|
| 58 |
+
marisa-trie==1.2.1
|
| 59 |
Markdown==3.8
|
| 60 |
markdown-it-py==3.0.0
|
| 61 |
MarkupSafe==3.0.2
|
| 62 |
mdurl==0.1.2
|
| 63 |
ml_dtypes==0.5.1
|
| 64 |
+
mpmath==1.3.0
|
| 65 |
multidict==6.3.2
|
| 66 |
multiprocess==0.70.16
|
| 67 |
+
murmurhash==1.0.13
|
| 68 |
namex==0.0.8
|
| 69 |
+
networkx==3.4.2
|
| 70 |
numpy==2.1.3
|
| 71 |
+
nvidia-cublas-cu12==12.6.4.1
|
| 72 |
+
nvidia-cuda-cupti-cu12==12.6.80
|
| 73 |
+
nvidia-cuda-nvrtc-cu12==12.6.77
|
| 74 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
| 75 |
+
nvidia-cudnn-cu12==9.5.1.17
|
| 76 |
+
nvidia-cufft-cu12==11.3.0.4
|
| 77 |
+
nvidia-cufile-cu12==1.11.1.6
|
| 78 |
+
nvidia-curand-cu12==10.3.7.77
|
| 79 |
+
nvidia-cusolver-cu12==11.7.1.2
|
| 80 |
+
nvidia-cusparse-cu12==12.5.4.2
|
| 81 |
+
nvidia-cusparselt-cu12==0.6.3
|
| 82 |
+
nvidia-nccl-cu12==2.26.2
|
| 83 |
+
nvidia-nvjitlink-cu12==12.6.85
|
| 84 |
+
nvidia-nvtx-cu12==12.6.77
|
| 85 |
opencv-python==4.11.0.86
|
| 86 |
openpyxl==3.1.5
|
| 87 |
opt_einsum==3.4.0
|
|
|
|
| 90 |
packaging==24.2
|
| 91 |
pandas==2.2.3
|
| 92 |
pillow==11.1.0
|
| 93 |
+
pillow_heif==1.0.0
|
| 94 |
+
preshed==3.0.10
|
| 95 |
propcache==0.3.1
|
| 96 |
protobuf==5.29.4
|
| 97 |
+
psutil==7.0.0
|
| 98 |
pyarrow==19.0.1
|
| 99 |
+
pydantic
|
| 100 |
+
pydantic_core
|
| 101 |
pydub==0.25.1
|
| 102 |
Pygments==2.19.1
|
| 103 |
PySocks==1.7.1
|
| 104 |
+
pythainlp==5.1.2
|
| 105 |
python-dateutil==2.9.0.post0
|
| 106 |
python-dotenv==1.1.0
|
| 107 |
python-multipart==0.0.20
|
| 108 |
pytz==2025.2
|
| 109 |
+
pyuca==1.2
|
| 110 |
PyYAML==6.0.2
|
| 111 |
+
qwen-vl-utils==0.0.8
|
| 112 |
+
regex==2024.11.6
|
| 113 |
requests==2.32.3
|
| 114 |
retina-face==0.0.17
|
| 115 |
rich==14.0.0
|
| 116 |
ruff==0.11.4
|
| 117 |
safehttpx==0.1.6
|
| 118 |
+
safetensors==0.5.3
|
| 119 |
semantic-version==2.10.0
|
| 120 |
shellingham==1.5.4
|
| 121 |
six==1.17.0
|
| 122 |
+
smart_open==7.3.0.post1
|
| 123 |
sniffio==1.3.1
|
| 124 |
soupsieve==2.6
|
| 125 |
+
spacy==3.8.7
|
| 126 |
+
spacy-legacy==3.0.12
|
| 127 |
+
spacy-loggers==1.0.5
|
| 128 |
+
spacy-thai==0.7.8
|
| 129 |
+
spacy-udpipe==1.0.0
|
| 130 |
+
srsly==2.5.1
|
| 131 |
starlette==0.46.1
|
| 132 |
+
sympy==1.14.0
|
| 133 |
tensorboard==2.19.0
|
| 134 |
tensorboard-data-server==0.7.2
|
| 135 |
tensorflow==2.19.0
|
| 136 |
tensorflow-io-gcs-filesystem==0.37.1
|
| 137 |
termcolor==3.0.1
|
| 138 |
tf_keras==2.19.0
|
| 139 |
+
thinc==8.3.6
|
| 140 |
+
timm==1.0.19
|
| 141 |
+
tokenizers==0.21.2
|
| 142 |
tomlkit==0.13.2
|
| 143 |
+
torch==2.7.1
|
| 144 |
+
https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
|
| 145 |
+
torchao==0.13.0
|
| 146 |
+
torchvision==0.22.1
|
| 147 |
tqdm==4.67.1
|
| 148 |
+
transformers==4.53.3
|
| 149 |
+
triton==3.3.1
|
| 150 |
typer==0.15.2
|
| 151 |
+
typing-inspection
|
| 152 |
+
typing_extensions
|
| 153 |
tzdata==2025.2
|
| 154 |
+
ufal.udpipe==1.3.1.1
|
| 155 |
urllib3==2.3.0
|
| 156 |
uvicorn==0.34.0
|
| 157 |
+
wasabi==1.1.3
|
| 158 |
+
weasel==0.4.1
|
| 159 |
websockets==15.0.1
|
| 160 |
Werkzeug==3.1.3
|
| 161 |
wrapt==1.17.2
|
| 162 |
xxhash==3.5.0
|
| 163 |
yarl==1.19.0
|
| 164 |
+
supabase==2.18.1
|
| 165 |
+
supabase_auth==2.12.3
|
| 166 |
+
supabase_functions==0.10.1
|
| 167 |
+
# flash_attn==2.8.1
|
|
|
|
|
|
|
|
|
|
|
|
ui/layout.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import time
|
|
|
|
| 3 |
|
| 4 |
from logic.data_utils import CustomHFDatasetSaver
|
| 5 |
from data.lang2eng_map import lang2eng_mapping
|
|
@@ -12,6 +13,161 @@ from .selection_page import build_selection_page
|
|
| 12 |
from .main_page import build_main_page
|
| 13 |
from .main_page import sort_with_pyuca
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def get_key_by_value(dictionary, value):
|
| 16 |
for key, val in dictionary.items():
|
| 17 |
if val == value:
|
|
@@ -100,14 +256,27 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 100 |
object-fit: contain; /* make sure the full image shows */
|
| 101 |
height: 460px; /* set a fixed height */
|
| 102 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
"""
|
| 104 |
############################################################################
|
| 105 |
with gr.Blocks(css=custom_css) as ui:
|
|
|
|
|
|
|
|
|
|
| 106 |
local_storage = gr.State([None, None, "", ""])
|
| 107 |
loading_example = gr.State(False) # to check if the values are loaded from a user click on an example in
|
| 108 |
# First page: selection
|
| 109 |
|
| 110 |
-
selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown = build_selection_page(metadata_dict)
|
| 111 |
|
| 112 |
# Second page
|
| 113 |
cmp_main_ui = build_main_page(concepts_dict, metadata_dict, local_storage)
|
|
@@ -144,8 +313,20 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 144 |
modal_exclude_confirm = cmp_main_ui["modal_exclude_confirm"]
|
| 145 |
cancel_exclude_btn = cmp_main_ui["cancel_exclude_btn"]
|
| 146 |
confirm_exclude_btn = cmp_main_ui["confirm_exclude_btn"]
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
### Category button
|
| 150 |
category_btn.change(
|
| 151 |
fn=partial(load_concepts, concepts=concepts_dict),
|
|
@@ -214,7 +395,7 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 214 |
clear_btn.click(
|
| 215 |
fn=clear_data,
|
| 216 |
outputs=[
|
| 217 |
-
image_inp, image_url_inp, long_caption_inp, exampleid_btn,
|
| 218 |
category_btn, concept_btn,
|
| 219 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 220 |
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
|
@@ -280,12 +461,12 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 280 |
# Handle clicking on an example
|
| 281 |
user_examples.click(
|
| 282 |
fn=partial(handle_click_example, concepts_dict=concepts_dict),
|
| 283 |
-
inputs=[user_examples],
|
| 284 |
outputs=[
|
| 285 |
image_inp, image_url_inp, long_caption_inp, exampleid_btn,
|
| 286 |
category_btn, concept_btn,
|
| 287 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 288 |
-
category_concept_dropdowns[3], category_concept_dropdowns[4], loading_example
|
| 289 |
],
|
| 290 |
)
|
| 291 |
|
|
@@ -295,6 +476,41 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 295 |
|
| 296 |
# ============================================ #
|
| 297 |
# Submit Button Click events
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
proceed_btn.click(
|
| 300 |
fn=partial(switch_ui, flag=False),
|
|
@@ -313,8 +529,8 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 313 |
]
|
| 314 |
).then(
|
| 315 |
fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
|
| 316 |
-
inputs=[
|
| 317 |
-
outputs=[user_examples, loading_msg],
|
| 318 |
)
|
| 319 |
|
| 320 |
|
|
@@ -322,7 +538,7 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 322 |
exit_btn.click(
|
| 323 |
fn=exit,
|
| 324 |
outputs=[
|
| 325 |
-
image_inp, image_url_inp, long_caption_inp, user_examples, loading_msg,
|
| 326 |
username, password, local_storage, exampleid_btn, category_btn, concept_btn,
|
| 327 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 328 |
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
|
@@ -368,7 +584,10 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 368 |
"excluded": gr.State(value=False),
|
| 369 |
"concepts_dict": gr.State(value=concepts_dict),
|
| 370 |
"country_lang_map": gr.State(value=lang2eng_mapping),
|
|
|
|
| 371 |
# "is_blurred": is_blurred
|
|
|
|
|
|
|
| 372 |
}
|
| 373 |
# data_outputs = [image_inp, image_url_inp, long_caption_inp,
|
| 374 |
# country_inp, language_inp, category_btn, concept_btn,
|
|
@@ -376,34 +595,56 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 376 |
hf_writer.setup(list(data_outputs.keys()), local_ds_folder = LOCAL_DS_DIRECTORY_PATH)
|
| 377 |
|
| 378 |
# STEP 4: Chain save_data, then update_user_data, then re-enable button, hide modal, and clear
|
| 379 |
-
submit_btn.click(
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
# ============================================ #
|
| 408 |
# instructions button
|
| 409 |
instruct_btn.click(lambda: Modal(visible=True), None, modal)
|
|
@@ -446,13 +687,13 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 446 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 447 |
category_concept_dropdowns[3], category_concept_dropdowns[4],
|
| 448 |
timestamp_btn, username_inp, password_inp, exampleid_btn, gr.State(value=True),
|
| 449 |
-
gr.State(value=concepts_dict), gr.State(value=lang2eng_mapping)
|
| 450 |
],
|
| 451 |
outputs=None
|
| 452 |
).success(
|
| 453 |
fn=partial(clear_data, "remove"),
|
| 454 |
-
outputs=[
|
| 455 |
-
image_inp, image_url_inp, long_caption_inp, exampleid_btn,
|
| 456 |
category_btn, concept_btn,
|
| 457 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 458 |
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
|
@@ -465,8 +706,32 @@ def build_ui(concepts_dict, metadata_dict, HF_API_TOKEN, HF_DATASET_NAME):
|
|
| 465 |
outputs=loading_msg
|
| 466 |
).success(
|
| 467 |
fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path=LOCAL_DS_DIRECTORY_PATH),
|
| 468 |
-
inputs=[
|
| 469 |
-
outputs=[user_examples, loading_msg]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
)
|
|
|
|
| 471 |
|
| 472 |
-
return ui
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import time
|
| 3 |
+
from logic.supabase_client import auth_handler
|
| 4 |
|
| 5 |
from logic.data_utils import CustomHFDatasetSaver
|
| 6 |
from data.lang2eng_map import lang2eng_mapping
|
|
|
|
| 13 |
from .main_page import build_main_page
|
| 14 |
from .main_page import sort_with_pyuca
|
| 15 |
|
| 16 |
+
js_code = """
|
| 17 |
+
function() {
|
| 18 |
+
// Get the full URL with the fragment
|
| 19 |
+
const url = window.location.href;
|
| 20 |
+
const fragment = url.split('#')[1];
|
| 21 |
+
|
| 22 |
+
if (!fragment) {
|
| 23 |
+
return "";
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
// Parse the fragment into an object
|
| 27 |
+
const params = new URLSearchParams(fragment);
|
| 28 |
+
const access_token = params.get('access_token');
|
| 29 |
+
const refresh_token = params.get('refresh_token');
|
| 30 |
+
|
| 31 |
+
// Create a JSON string with the tokens
|
| 32 |
+
const tokens = JSON.stringify({
|
| 33 |
+
access_token: access_token,
|
| 34 |
+
refresh_token: refresh_token
|
| 35 |
+
});
|
| 36 |
+
|
| 37 |
+
// Return the JSON string to the Gradio output component
|
| 38 |
+
return tokens;
|
| 39 |
+
}
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def login_user(email, password):
    """Authenticate against Supabase and build the browser-persisted payload.

    Returns a ``(client, persistent_data, message)`` triple; ``client`` is
    ``None`` and the persisted tokens are blank when the login attempt fails.
    """
    result = auth_handler.login(email, password)
    if not result['success']:
        # Failed login: clear any persisted credentials.
        return None, {"refresh_token": "", "user_email": ""}, result['message']

    session = result['data']
    persisted = {
        "refresh_token": session['refresh_token'],
        "user_email": session['user_email'],
    }
    return session['client'], persisted, result['message']
|
| 57 |
+
|
| 58 |
+
def login_user_recovery(session_data: str):
    """Recover a Supabase session from URL-fragment tokens sent by the frontend.

    *session_data* is a JSON string holding ``access_token`` and
    ``refresh_token``. Returns ``(client, persistent_data, message)`` in the
    same shape as ``login_user``; on malformed input the browser-persisted
    state is left untouched via ``gr.skip()``.
    """
    import json

    try:
        parsed = json.loads(session_data)
        access_token = parsed.get("access_token")
        refresh_token = parsed.get("refresh_token")

        if not (access_token and refresh_token):
            return None, gr.skip(), "Invalid session data provided."

        result = auth_handler.retrieve_session_from_tokens(access_token, refresh_token)
        if result['success']:
            session = result['data']
            persisted = {
                "refresh_token": session['refresh_token'],
                "user_email": session['user_email'],
            }
            return session['client'], persisted, result['message']

        # Token exchange failed: wipe the persisted credentials.
        return None, {"refresh_token": "", "user_email": ""}, result['message']

    except Exception as e:
        # Bad JSON / unexpected shape: report but keep persisted state as-is.
        return None, gr.skip(), f"Failed to process recovery login: {e}"
|
| 90 |
+
|
| 91 |
+
def sign_up(email, password):
    """Register a new Supabase account and surface the handler's status message."""
    return auth_handler.sign_up(email, password)['message']
|
| 94 |
+
|
| 95 |
+
def reset_password(email):
    """Trigger a Supabase password-reset email and return the status message."""
    return auth_handler.reset_password_for_email(email)['message']
|
| 98 |
+
|
| 99 |
+
def log_out(supabase_user_client, persistent_session):
    """Log the user out and return a cleared persisted session.

    The cleared dict is returned on every path — success, Supabase error, or
    missing client — so the UI never keeps a stale login. The incoming
    *persistent_session* value is intentionally ignored (the parameter is kept
    only because callers wire it as a Gradio input).
    """
    cleared = {
        "refresh_token": "",
        "user_email": ""
    }
    if not supabase_user_client:
        print("No user client provided to log out.")
        return cleared

    result = auth_handler.logout(supabase_user_client)
    if result['success']:
        print("User logged out successfully.")
    else:
        print(f"Error logging out: {result['message']}")
    return cleared
|
| 118 |
+
|
| 119 |
+
def restore_user_session(session_data, login_status=None):
    """Restore a Supabase session from the browser-persisted refresh token.

    Returns a 12-tuple consumed by the selection-page components:
    (client, persistent_data, login_status, proceed_btn, login_btn,
    sign_up_btn, reset_password_btn, logout_btn, change_password_field,
    change_password_field_confirm, change_password_btn,
    change_password_status).

    When *session_data* is missing/invalid or the token refresh fails, the
    anonymous-user UI state is returned (the same defaults the original
    duplicated in two branches).
    """
    print("Restoring user session with data:", session_data)

    # Default UI state: anonymous user — login/sign-up/reset visible,
    # logout and change-password controls hidden.
    anon_updates = (
        gr.update(value=login_status if login_status else ""),   # login_status
        gr.update(value="Proceed as Anonymous User", interactive=True),  # proceed_btn
        gr.update(visible=True),    # login_btn
        gr.update(visible=True),    # sign_up_btn
        gr.update(visible=True),    # reset_password_btn
        gr.update(visible=False),   # logout_btn
        gr.update(visible=False),   # change_password_field
        gr.update(visible=False),   # change_password_field_confirm
        gr.update(visible=False),   # change_password_btn
        gr.update(value=""),        # change_password_status
    )
    empty_persistent = {
        "refresh_token": "",
        "user_email": ""
    }

    if not session_data or not session_data.get('refresh_token', ''):
        print("No session data found, proceeding as anonymous user.")
        return (None, empty_persistent) + anon_updates

    result = auth_handler.restore_session(session_data['refresh_token'])
    if not result['success']:
        # Refresh failed: fall back to the anonymous defaults.
        return (None, empty_persistent) + anon_updates

    restored = result['data']
    new_persistent = {
        "refresh_token": restored['refresh_token'],
        "user_email": restored['user_email']
    }
    # Logged-in UI state: hide login/sign-up/reset, show logout and
    # the change-password controls.
    logged_in_updates = (
        gr.update(value=result['message']),                      # login_status
        gr.update(value="Proceed", interactive=True),            # proceed_btn
        gr.update(visible=False),   # login_btn
        gr.update(visible=False),   # sign_up_btn
        gr.update(visible=False),   # reset_password_btn
        gr.update(visible=True),    # logout_btn
        gr.update(visible=True),    # change_password_field
        gr.update(visible=True),    # change_password_field_confirm
        gr.update(visible=True),    # change_password_btn
        gr.update(value=""),        # change_password_status
    )
    return (restored['client'], new_persistent) + logged_in_updates
|
| 160 |
+
|
| 161 |
+
def change_password(supabase_user_client, new_password, confirm_password):
    """Update the logged-in user's password once both entries match.

    Returns a status message suitable for display in the UI.
    """
    if new_password == confirm_password:
        return auth_handler.change_password(supabase_user_client, new_password)['message']
    return "Passwords do not match. Please try again."
|
| 169 |
+
|
| 170 |
+
|
| 171 |
def get_key_by_value(dictionary, value):
|
| 172 |
for key, val in dictionary.items():
|
| 173 |
if val == value:
|
|
|
|
| 256 |
object-fit: contain; /* make sure the full image shows */
|
| 257 |
height: 460px; /* set a fixed height */
|
| 258 |
}
|
| 259 |
+
#vlm_output .input-container {
|
| 260 |
+
position: relative;
|
| 261 |
+
}
|
| 262 |
+
#vlm_output .input-container::before {
|
| 263 |
+
content: "";
|
| 264 |
+
position: absolute;
|
| 265 |
+
top: 0; left: 0; right: 0; bottom: 0;
|
| 266 |
+
z-index: 10; /* sits above the textarea */
|
| 267 |
+
background: transparent;
|
| 268 |
+
}
|
| 269 |
"""
|
| 270 |
############################################################################
|
| 271 |
with gr.Blocks(css=custom_css) as ui:
|
| 272 |
+
supabase_user_client = gr.State(None)
|
| 273 |
+
persistent_session = gr.BrowserState(None)
|
| 274 |
+
|
| 275 |
local_storage = gr.State([None, None, "", ""])
|
| 276 |
loading_example = gr.State(False) # to check if the values are loaded from a user click on an example in
|
| 277 |
# First page: selection
|
| 278 |
|
| 279 |
+
selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown, login_btn, sign_up_btn, reset_password_btn, login_status, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status = build_selection_page(metadata_dict)
|
| 280 |
|
| 281 |
# Second page
|
| 282 |
cmp_main_ui = build_main_page(concepts_dict, metadata_dict, local_storage)
|
|
|
|
| 313 |
modal_exclude_confirm = cmp_main_ui["modal_exclude_confirm"]
|
| 314 |
cancel_exclude_btn = cmp_main_ui["cancel_exclude_btn"]
|
| 315 |
confirm_exclude_btn = cmp_main_ui["confirm_exclude_btn"]
|
| 316 |
+
vlm_output = cmp_main_ui["vlm_output"]
|
| 317 |
+
gen_button = cmp_main_ui["gen_button"]
|
| 318 |
+
vlm_feedback = cmp_main_ui["vlm_feedback"]
|
| 319 |
+
modal_vlm = cmp_main_ui["modal_vlm"]
|
| 320 |
+
vlm_no_btn = cmp_main_ui["vlm_no_btn"]
|
| 321 |
+
vlm_done_btn = cmp_main_ui["vlm_done_btn"]
|
| 322 |
+
submit_yes = cmp_main_ui["submit_yes"]
|
| 323 |
+
submit_no = cmp_main_ui["submit_no"]
|
| 324 |
+
modal_submit = cmp_main_ui["modal_submit"]
|
| 325 |
+
vlm_cancel_btn = cmp_main_ui["vlm_cancel_btn"]
|
| 326 |
+
vlm_model_dropdown = cmp_main_ui["vlm_model_dropdown"]
|
| 327 |
+
|
| 328 |
+
# dictionary to store all vlm_output by exampleid
|
| 329 |
+
vlm_captions = gr.State(None)
|
| 330 |
### Category button
|
| 331 |
category_btn.change(
|
| 332 |
fn=partial(load_concepts, concepts=concepts_dict),
|
|
|
|
| 395 |
clear_btn.click(
|
| 396 |
fn=clear_data,
|
| 397 |
outputs=[
|
| 398 |
+
image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
|
| 399 |
category_btn, concept_btn,
|
| 400 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 401 |
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
|
|
|
| 461 |
# Handle clicking on an example
|
| 462 |
user_examples.click(
|
| 463 |
fn=partial(handle_click_example, concepts_dict=concepts_dict),
|
| 464 |
+
inputs=[user_examples, vlm_captions],
|
| 465 |
outputs=[
|
| 466 |
image_inp, image_url_inp, long_caption_inp, exampleid_btn,
|
| 467 |
category_btn, concept_btn,
|
| 468 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 469 |
+
category_concept_dropdowns[3], category_concept_dropdowns[4], loading_example, vlm_output
|
| 470 |
],
|
| 471 |
)
|
| 472 |
|
|
|
|
| 476 |
|
| 477 |
# ============================================ #
|
| 478 |
# Submit Button Click events
|
| 479 |
+
login_btn.click(
|
| 480 |
+
fn=login_user,
|
| 481 |
+
inputs=[username, password],
|
| 482 |
+
outputs=[supabase_user_client, persistent_session, login_status],
|
| 483 |
+
).then(
|
| 484 |
+
fn=restore_user_session,
|
| 485 |
+
inputs=[persistent_session, login_status],
|
| 486 |
+
outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
sign_up_btn.click(
|
| 490 |
+
fn=sign_up,
|
| 491 |
+
inputs=[username, password],
|
| 492 |
+
outputs=[login_status],
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
logout_btn.click(
|
| 496 |
+
fn=log_out,
|
| 497 |
+
inputs=[supabase_user_client, persistent_session],
|
| 498 |
+
outputs=[persistent_session]
|
| 499 |
+
).then(
|
| 500 |
+
fn=restore_user_session,
|
| 501 |
+
inputs=[persistent_session],
|
| 502 |
+
outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
|
| 503 |
+
)
|
| 504 |
+
change_password_btn.click(
|
| 505 |
+
fn=change_password,
|
| 506 |
+
inputs=[supabase_user_client, change_password_field, change_password_field_confirm],
|
| 507 |
+
outputs=[change_password_status]
|
| 508 |
+
)
|
| 509 |
+
reset_password_btn.click(
|
| 510 |
+
fn=reset_password,
|
| 511 |
+
inputs=[username],
|
| 512 |
+
outputs=[login_status]
|
| 513 |
+
)
|
| 514 |
|
| 515 |
proceed_btn.click(
|
| 516 |
fn=partial(switch_ui, flag=False),
|
|
|
|
| 529 |
]
|
| 530 |
).then(
|
| 531 |
fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
|
| 532 |
+
inputs=[supabase_user_client, country_choice, language_choice],
|
| 533 |
+
outputs=[user_examples, loading_msg, vlm_captions],
|
| 534 |
)
|
| 535 |
|
| 536 |
|
|
|
|
| 538 |
exit_btn.click(
|
| 539 |
fn=exit,
|
| 540 |
outputs=[
|
| 541 |
+
image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, user_examples, loading_msg,
|
| 542 |
username, password, local_storage, exampleid_btn, category_btn, concept_btn,
|
| 543 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 544 |
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
|
|
|
| 584 |
"excluded": gr.State(value=False),
|
| 585 |
"concepts_dict": gr.State(value=concepts_dict),
|
| 586 |
"country_lang_map": gr.State(value=lang2eng_mapping),
|
| 587 |
+
"client": supabase_user_client,
|
| 588 |
# "is_blurred": is_blurred
|
| 589 |
+
"vlm_caption": vlm_output,
|
| 590 |
+
"vlm_feedback": vlm_feedback
|
| 591 |
}
|
| 592 |
# data_outputs = [image_inp, image_url_inp, long_caption_inp,
|
| 593 |
# country_inp, language_inp, category_btn, concept_btn,
|
|
|
|
| 595 |
hf_writer.setup(list(data_outputs.keys()), local_ds_folder = LOCAL_DS_DIRECTORY_PATH)
|
| 596 |
|
| 597 |
# STEP 4: Chain save_data, then update_user_data, then re-enable button, hide modal, and clear
|
| 598 |
+
# submit_btn.click(lambda: Modal(visible=True), None, modal_vlm)
|
| 599 |
+
submit_btn.click(submit_button_clicked,
|
| 600 |
+
inputs=[vlm_output],
|
| 601 |
+
outputs=[modal_vlm, modal_submit])
|
| 602 |
+
|
| 603 |
+
# submit_btn.click(partial(submit_button_clicked, save_fn=hf_writer.save,
|
| 604 |
+
# data_outputs=data_outputs),
|
| 605 |
+
# inputs=[vlm_output],
|
| 606 |
+
# outputs=[modal_vlm, image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, exampleid_btn,
|
| 607 |
+
# category_btn, concept_btn, category_concept_dropdowns[0], category_concept_dropdowns[1],
|
| 608 |
+
# category_concept_dropdowns[2], category_concept_dropdowns[3], category_concept_dropdowns[4]])
|
| 609 |
+
|
| 610 |
+
def wire_submit_chain(button, modal_ui):
|
| 611 |
+
e = button.click(
|
| 612 |
+
fn=lambda: Modal(visible=False),
|
| 613 |
+
outputs=[modal_ui]
|
| 614 |
+
).success(
|
| 615 |
+
hf_writer.save,
|
| 616 |
+
inputs = list(data_outputs.values()),
|
| 617 |
+
outputs = None,
|
| 618 |
+
).success(
|
| 619 |
+
fn=partial(clear_data, "submit"),
|
| 620 |
+
outputs=[
|
| 621 |
+
image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
|
| 622 |
+
category_btn, concept_btn,
|
| 623 |
+
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 624 |
+
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
| 625 |
+
],
|
| 626 |
+
# ).success(enable_submit,
|
| 627 |
+
# None, [submit_btn]
|
| 628 |
+
# ).success(lambda: Modal(visible=False),
|
| 629 |
+
# None, modal_saving
|
| 630 |
+
# ).success(lambda: Modal(visible=True),
|
| 631 |
+
# None, modal_data_saved
|
| 632 |
+
).success(
|
| 633 |
+
# set loading msg
|
| 634 |
+
lambda: gr.update(value="**Loading your data, please wait ...**"),
|
| 635 |
+
None, loading_msg
|
| 636 |
+
).success(
|
| 637 |
+
fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path = LOCAL_DS_DIRECTORY_PATH),
|
| 638 |
+
inputs=[supabase_user_client, country_choice, language_choice],
|
| 639 |
+
outputs=[user_examples, loading_msg, vlm_captions]
|
| 640 |
+
)
|
| 641 |
+
return e
|
| 642 |
+
|
| 643 |
+
wire_submit_chain(vlm_done_btn, modal_vlm)
|
| 644 |
+
wire_submit_chain(vlm_no_btn, modal_vlm)
|
| 645 |
+
wire_submit_chain(submit_yes, modal_submit)
|
| 646 |
+
submit_no.click(lambda: Modal(visible=False), None, modal_submit)
|
| 647 |
+
vlm_cancel_btn.click(lambda: Modal(visible=False), None, modal_vlm)
|
| 648 |
# ============================================ #
|
| 649 |
# instructions button
|
| 650 |
instruct_btn.click(lambda: Modal(visible=True), None, modal)
|
|
|
|
| 687 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 688 |
category_concept_dropdowns[3], category_concept_dropdowns[4],
|
| 689 |
timestamp_btn, username_inp, password_inp, exampleid_btn, gr.State(value=True),
|
| 690 |
+
gr.State(value=concepts_dict), gr.State(value=lang2eng_mapping), vlm_output, vlm_feedback
|
| 691 |
],
|
| 692 |
outputs=None
|
| 693 |
).success(
|
| 694 |
fn=partial(clear_data, "remove"),
|
| 695 |
+
outputs=[
|
| 696 |
+
image_inp, image_url_inp, long_caption_inp, vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button, exampleid_btn,
|
| 697 |
category_btn, concept_btn,
|
| 698 |
category_concept_dropdowns[0], category_concept_dropdowns[1], category_concept_dropdowns[2],
|
| 699 |
category_concept_dropdowns[3], category_concept_dropdowns[4]
|
|
|
|
| 706 |
outputs=loading_msg
|
| 707 |
).success(
|
| 708 |
fn=partial(update_user_data, HF_DATASET_NAME=HF_DATASET_NAME, local_ds_directory_path=LOCAL_DS_DIRECTORY_PATH),
|
| 709 |
+
inputs=[supabase_user_client, country_choice, language_choice],
|
| 710 |
+
outputs=[user_examples, loading_msg, vlm_captions]
|
| 711 |
+
)
|
| 712 |
+
# ============================================= #
|
| 713 |
+
# VLM Gen button
|
| 714 |
+
# ============================================= #
|
| 715 |
+
gen_button.click(
|
| 716 |
+
fn=generate_vlm_caption, # processor=processor, model=model
|
| 717 |
+
inputs=[image_inp, vlm_model_dropdown],
|
| 718 |
+
outputs=[vlm_output, vlm_feedback, vlm_done_btn, vlm_no_btn, gen_button]
|
| 719 |
+
)
|
| 720 |
+
# vlm_output.change(
|
| 721 |
+
# fn=lambda : gr.update(interactive=False) if vlm_output.value else gr.update(interactive=True),
|
| 722 |
+
# inputs=[],
|
| 723 |
+
# outputs=[gen_button]
|
| 724 |
+
# )
|
| 725 |
+
|
| 726 |
+
ui.load(
|
| 727 |
+
fn=login_user_recovery,
|
| 728 |
+
inputs=gr.Textbox(visible=False, value=""), # hidden textbox to get the url tokens
|
| 729 |
+
outputs=[supabase_user_client, persistent_session, login_status],
|
| 730 |
+
js=js_code
|
| 731 |
+
).then(
|
| 732 |
+
fn=restore_user_session,
|
| 733 |
+
inputs=[persistent_session],
|
| 734 |
+
outputs=[supabase_user_client, persistent_session, login_status, proceed_btn, login_btn, sign_up_btn, reset_password_btn, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status],
|
| 735 |
)
|
| 736 |
+
return ui
|
| 737 |
|
|
|
ui/main_page.py
CHANGED
|
@@ -107,7 +107,36 @@ def build_main_page(concepts_dict, metadata_dict, local_storage):
|
|
| 107 |
long_caption_inp = gr.Textbox(lines=6, label="Description", elem_id="long_caption_inp")
|
| 108 |
num_words_inp = gr.Textbox(lines=1, label="Number of words", elem_id="num_words", interactive=False, value=0)
|
| 109 |
# num_words_inp = gr.Markdown("Number of words", elem_id="num_words")
|
|
|
|
|
|
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
categories_list = sort_with_pyuca(list(concepts_dict["USA"]["English"].keys()))
|
| 112 |
|
| 113 |
def create_category_dropdown(category, index):
|
|
@@ -226,5 +255,16 @@ def build_main_page(concepts_dict, metadata_dict, local_storage):
|
|
| 226 |
"modal_exclude_confirm": modal_exclude_confirm,
|
| 227 |
"cancel_exclude_btn": cancel_exclude_btn,
|
| 228 |
"confirm_exclude_btn": confirm_exclude_btn,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
}
|
| 230 |
return output_dict
|
|
|
|
| 107 |
long_caption_inp = gr.Textbox(lines=6, label="Description", elem_id="long_caption_inp")
|
| 108 |
num_words_inp = gr.Textbox(lines=1, label="Number of words", elem_id="num_words", interactive=False, value=0)
|
| 109 |
# num_words_inp = gr.Markdown("Number of words", elem_id="num_words")
|
| 110 |
+
#########################################################
|
| 111 |
+
with Modal(visible=False, allow_user_close=False) as modal_vlm:
|
| 112 |
|
| 113 |
+
question = gr.Markdown("Would you like to see if a VLM can generate a culturally aware description for your uploaded concept?")
|
| 114 |
+
with gr.Row():
|
| 115 |
+
gen_button = gr.Button("Yes", variant="primary", elem_id="generate_answer_btn")
|
| 116 |
+
vlm_no_btn = gr.Button("No")
|
| 117 |
+
vlm_cancel_btn = gr.Button("Cancel")
|
| 118 |
+
vlm_model_dropdown = gr.Dropdown(
|
| 119 |
+
["SmolVLM-500M", "Qwen2.5-VL-7B", "InternVL3_5-8B", "Gemma3-4B"], value="Gemma3-4B", multiselect=False, label="VLM Model", info="Select the VLM model to use for generating the description."
|
| 120 |
+
)
|
| 121 |
+
vlm_output = gr.Textbox(lines=6, label="Generated description", elem_id="vlm_output", interactive=False)
|
| 122 |
+
vlm_feedback = gr.Radio(["Yes π", "No π"], label="Do you think the generated description is accurate within the cultural context of your country?", visible=False, elem_id="vlm_feedback", interactive=True)
|
| 123 |
+
vlm_done_btn = gr.Button("Complete Submission", visible=False)
|
| 124 |
+
|
| 125 |
+
with Modal(visible=False, allow_user_close=False) as modal_submit:
|
| 126 |
+
|
| 127 |
+
gr.Markdown("β οΈ You've already generated a caption for this image. An optional description with the VLM can only be generated once. Would you like to proceed and submit your modified data?")
|
| 128 |
+
with gr.Row():
|
| 129 |
+
submit_yes = gr.Button("Yes", variant="primary", elem_id="submit_confirm_yes")
|
| 130 |
+
submit_no = gr.Button("No", variant="stop", elem_id="submit_confirm_no")
|
| 131 |
+
|
| 132 |
+
# with gr.Group():
|
| 133 |
+
# gr.Markdown("### VLM Generation (Optional)")
|
| 134 |
+
# with gr.Accordion("π Click here if you want to get a generated answer from a small vlm", open=False):
|
| 135 |
+
# gen_button = gr.Button("Generate Answer", variant="primary", elem_id="generate_answer_btn")
|
| 136 |
+
# vlm_output = gr.Textbox(lines=6, label="Generated Answer", elem_id="vlm_output", interactive=False)
|
| 137 |
+
# vlm_feedback = gr.Radio(["Yes π", "No π"], label="Do you like the generated caption?", visible=False, elem_id="vlm_feedback", interactive=True)
|
| 138 |
+
##########################################################
|
| 139 |
+
|
| 140 |
categories_list = sort_with_pyuca(list(concepts_dict["USA"]["English"].keys()))
|
| 141 |
|
| 142 |
def create_category_dropdown(category, index):
|
|
|
|
| 255 |
"modal_exclude_confirm": modal_exclude_confirm,
|
| 256 |
"cancel_exclude_btn": cancel_exclude_btn,
|
| 257 |
"confirm_exclude_btn": confirm_exclude_btn,
|
| 258 |
+
"vlm_output": vlm_output,
|
| 259 |
+
"gen_button": gen_button,
|
| 260 |
+
"vlm_feedback": vlm_feedback,
|
| 261 |
+
"modal_vlm": modal_vlm,
|
| 262 |
+
"vlm_no_btn": vlm_no_btn,
|
| 263 |
+
"vlm_done_btn": vlm_done_btn,
|
| 264 |
+
"submit_yes": submit_yes,
|
| 265 |
+
"submit_no": submit_no,
|
| 266 |
+
"modal_submit": modal_submit,
|
| 267 |
+
"vlm_cancel_btn": vlm_cancel_btn,
|
| 268 |
+
"vlm_model_dropdown": vlm_model_dropdown
|
| 269 |
}
|
| 270 |
return output_dict
|
ui/selection_page.py
CHANGED
|
@@ -57,6 +57,24 @@ def build_selection_page(metadata_dict):
|
|
| 57 |
username = gr.Textbox(label="Email (optional)", type="email", elem_id="username_text")
|
| 58 |
password = gr.Textbox(label="Password (optional)", type="password", elem_id="password_text")
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
username = gr.Textbox(label="Email (optional)", type="email", elem_id="username_text")
|
| 58 |
password = gr.Textbox(label="Password (optional)", type="password", elem_id="password_text")
|
| 59 |
|
| 60 |
+
with gr.Row():
|
| 61 |
+
login_btn = gr.Button("Login", elem_id="login_btn")
|
| 62 |
+
sign_up_btn = gr.Button("Sign up", elem_id="sign_up_btn")
|
| 63 |
+
reset_password_btn = gr.Button("Reset Password", elem_id="reset_password_btn")
|
| 64 |
+
logout_btn = gr.Button("Logout", elem_id="logout_btn",visible=False)
|
| 65 |
|
| 66 |
+
login_status = gr.Markdown("")
|
| 67 |
+
with gr.Row():
|
| 68 |
+
proceed_btn = gr.Button("Proceed")
|
| 69 |
+
with gr.Row():
|
| 70 |
+
change_password_field = gr.Textbox(
|
| 71 |
+
label="Change Password", type="password", elem_id="change_password_field", visible=True
|
| 72 |
+
)
|
| 73 |
+
change_password_field_confirm = gr.Textbox(
|
| 74 |
+
label="Confirm New Password", type="password", elem_id="change_password_field_confirm", visible=True
|
| 75 |
+
)
|
| 76 |
+
with gr.Row():
|
| 77 |
+
change_password_btn = gr.Button("Change Password", elem_id="change_password_btn", visible=True)
|
| 78 |
+
change_password_status = gr.Markdown("")
|
| 79 |
+
|
| 80 |
+
return selection_page, country_choice, language_choice, proceed_btn, username, password, intro_markdown, login_btn, sign_up_btn, reset_password_btn, login_status, logout_btn, change_password_field, change_password_field_confirm, change_password_btn, change_password_status
|