Spaces:
Running
on
T4
Running
on
T4
| import os.path | |
| import datetime | |
| import io | |
| import PIL | |
| import requests | |
| from datasets import load_dataset, concatenate_datasets, Image | |
| from data.lang2eng_map import lang2eng_mapping | |
| from data.words_map import words_mapping | |
| import gradio as gr | |
| import bcrypt | |
| from config.settings import HF_API_TOKEN | |
| from huggingface_hub import snapshot_download | |
| # from .blur import blur_faces, detect_faces | |
| from retinaface import RetinaFace | |
| from gradio_modal import Modal | |
| import numpy as np | |
| import cv2 | |
| import time | |
| import re | |
| import os | |
| import glob | |
def update_image(image_url):
    """Download the image at ``image_url`` and load it into the image input.

    Args:
        image_url: URL typed by the user; may be None or empty.

    Returns:
        A 2-tuple ``(image, modal)``:
          * empty URL: a reset image component and a hidden modal;
          * success: the downloaded image converted to RGB and a hidden modal;
          * any failure (HTTP/network error, non-image Content-Type, undecodable
            bytes): a cleared image component and a visible modal (presumably the
            "invalid URL" dialog — confirm against the UI wiring).
    """
    # Local import: ``import PIL`` alone does not load the ``PIL.Image`` submodule.
    from PIL import Image as PILImage

    if not image_url:
        # Nothing to fetch; reset the component without showing the error modal.
        return gr.Image(label="Image", elem_id="image_inp"), Modal(visible=False)
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        response = requests.get(image_url, headers=headers, timeout=10)
        response.raise_for_status()
        content_type = response.headers.get("Content-Type", "")
        if "image" not in content_type:
            # gr.Error only displays when raised; use a toast and fall through to the
            # same two-output error UI as every other failure.
            gr.Warning("⚠️ URL does not point to a valid image.", duration=5)
            return gr.Image(label="Image", value=None, elem_id="image_inp"), Modal(visible=True)
        img = PILImage.open(io.BytesIO(response.content)).convert("RGB")
        return img, Modal(visible=False)
    except Exception as e:
        # UI boundary: report and show the error modal instead of crashing the event.
        print(f"Image URL load error: {e}")
        return gr.Image(label="Image", value=None, elem_id="image_inp"), Modal(visible=True)
def update_timestamp():
    """Return a hidden Timestamp textbox holding the current POSIX time (seconds)."""
    now_seconds = datetime.datetime.now().timestamp()
    return gr.Textbox(now_seconds, label="Timestamp", visible=False)
def clear_data():
    """Return the values that reset the submission form.

    Returns:
        An 11-tuple: five ``None`` values, one update setting ``value=None``, then
        five updates setting ``value=[]``.
    """
    plain_resets = (None,) * 5
    single_reset = (gr.update(value=None),)
    list_resets = tuple(gr.update(value=[]) for _ in range(5))
    return plain_resets + single_reset + list_resets
def exit():
    """Return the values that reset the UI when the user exits.

    Returns:
        A 16-tuple: three ``None`` values, an empty ``gr.Dataset``, a "loading"
        Markdown notice, two ``value=None`` updates, the local-storage reset list
        ``[None, None, "", ""]``, then eight more ``value=None`` updates.
    """
    empty_examples = gr.Dataset(samples=[])
    loading_notice = gr.Markdown("**Loading your data, please wait ...**")
    storage_reset = [None, None, "", ""]
    leading_clears = [gr.update(value=None) for _ in range(2)]
    trailing_clears = [gr.update(value=None) for _ in range(8)]
    return (
        None, None, None,
        empty_examples,
        loading_notice,
        *leading_clears,
        storage_reset,
        *trailing_clears,
    )
def validate_inputs(image, ori_img):  # is_blurred
    """Enable the Submit button once an image exists and snapshot the original.

    Args:
        image: Current image value from the image input, or None when empty
            (presumably a numpy array, since ``.copy()`` is called on it).
        ori_img: Previously stored original image wrapped in ``gr.State``, or None
            if nothing has been stored yet.

    Returns:
        A 3-tuple ``(submit_button, image, ori_img)``. With no image the button is
        disabled and the other two values are None.
    """
    if image is None:
        return gr.Button("Submit", variant="primary", interactive=False), None, None
    # NOTE: images are intentionally kept at full resolution (an earlier version
    # downscaled anything larger than 1024x1024 here).
    if ori_img is None:
        # First image seen: keep an untouched copy, stored as gr.State (other code
        # in this module reads such states through ``.value``).
        ori_img = gr.State(image.copy())
    return gr.Button("Submit", variant="primary", interactive=True), image, ori_img  # is_blurred
def add_prefix(example, column_name, prefix):
    """Prepend ``"<prefix>/"`` to ``example[column_name]`` in place.

    Suitable as a ``Dataset.map`` function.

    Returns:
        The same (mutated) ``example`` mapping.
    """
    relative_value = example[column_name]
    example[column_name] = "{}/".format(prefix) + relative_value
    return example
def update_user_data(username, password, country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
    """Load the user's previous submissions and build the examples gallery.

    Steps:
      1. Best-effort ``snapshot_download`` of the user's files from the Hub dataset
         repo into ``local_ds_directory_path``.
      2. If JSON records exist under
         ``<local_ds_directory_path>/logged_in_users/<country>/<language>/<username>/``,
         load them and resolve their image paths against ``local_ds_directory_path``.
      3. Keep rows whose ``username`` matches and whose bcrypt hash accepts
         ``password``; newest first; one row per ``id``; drop excluded ids.

    Returns:
        ``(gr.Dataset, message)``: ``message`` is a red Markdown notice when no
        data could be loaded at all, otherwise None.
    """
    datasets_list = []

    # Best-effort sync from the Hub; a failure just means we use local files only.
    try:
        snapshot_download(
            repo_id=HF_DATASET_NAME,
            repo_type="dataset",
            local_dir=local_ds_directory_path,
            # NOTE(review): this pattern has no "logged_in_users/" prefix, while the
            # local lookup below only reads logged_in_users/...; confirm the Hub repo
            # layout actually matches.
            allow_patterns=f"{country}/{language_choice}/{username}/*",
            token=HF_API_TOKEN,
        )
    except Exception as e:
        print(f"Snapshot download error: {e}")

    if has_user_json(username, country, language_choice, local_ds_directory_path):
        try:
            ds_local = load_dataset(
                local_ds_directory_path,
                data_files=f"logged_in_users/{country}/{language_choice}/{username}/**/*.json",
            )
            # Replace "image_file" with the "image" column, turned into a path under
            # local_ds_directory_path and then decoded as an image feature.
            ds_local = ds_local.remove_columns("image_file")
            ds_local = ds_local.rename_column("image", "image_file")
            ds_local = ds_local.map(add_prefix, fn_kwargs={"column_name": "image_file", "prefix": local_ds_directory_path})
            ds_local = ds_local.cast_column("image_file", Image())
            # load_dataset returns a DatasetDict; use its first split.
            datasets_list.append(next(iter(ds_local.values())))
        except Exception as e:
            print(f"Local dataset load error: {e}")

    if not datasets_list:
        return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>")

    dataset = concatenate_datasets(datasets_list)
    # TODO: we should link username with password and language and country, otherwise there will be an error when loading with different language and clicking on the example
    if not (username and password):
        # TODO: should we show the entire dataset instead? What about "other data" tab?
        return gr.Dataset(samples=[]), None

    # bcrypt is deliberately slow and a user's rows likely share one stored hash,
    # so verify each distinct hash once instead of once per row.
    verified_hashes = {}

    def _is_owned_by_user(row):
        if row["username"] != username:
            return False
        stored_hash = row["password"]
        if stored_hash not in verified_hashes:
            verified_hashes[stored_hash] = is_password_correct(stored_hash, password)
        return verified_hashes[stored_hash]

    user_dataset = dataset.filter(_is_owned_by_user).sort("timestamp", reverse=True)

    # Rows are newest-first, so the first row seen per id is its latest revision;
    # if that revision is excluded, the id is hidden entirely.
    seen_ids = set()
    samples = []
    for d in user_dataset:
        if d["id"] in seen_ids:
            continue
        seen_ids.add(d["id"])
        if d["excluded"]:
            continue
        # One concept list per additional category; [""] when the field is absent.
        additional_concepts_by_category = [d.get(f"category_{i}_concepts", [""]) for i in range(1, 6)]
        samples.append(
            [
                d["image_file"], d["image_url"], d["caption"] or "", d["country"],
                d["language"], d["category"], d["concept"], additional_concepts_by_category, d["id"],
            ]
        )
    return gr.Dataset(samples=samples), None
def _read_localized_text(metadata_dict, country, language, key):
    """Return the contents of the file at ``metadata_dict[country][language][key]``.

    Falls back to ``metadata_dict["USA"]["English"][key]`` when the localized
    file does not exist on disk.
    """
    path = metadata_dict[country][language][key]
    if not os.path.exists(path):
        path = metadata_dict["USA"]["English"][key]
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def update_language(local_storage, metadata_dict, concepts_dict):
    """Relabel the whole form for the selected country/language.

    Args:
        local_storage: 4-sequence ``(country, language, email, password)``.
        metadata_dict: ``metadata_dict[country][language]`` maps UI keys (e.g.
            "Country", "Submit_btn") to localized strings, and "Task" /
            "Instructions" to text-file paths.
        concepts_dict: ``concepts_dict[country][<language mapped via
            lang2eng_mapping>]`` maps each category name to its concepts.

    Returns:
        A 29-tuple of Gradio updates, in this order: login fields (4), category
        and concept dropdowns (2), image / URL / description labels (3), task and
        instruction Markdown (2), Instructions / Clear / Submit buttons (3),
        saving / saved texts (2), timestamp label, Exit button, browse / loading
        texts (2), face-hiding and exclude buttons (4), and five additional-concept
        dropdowns (5).
    """
    country, language, email, password = local_storage
    labels = metadata_dict[country][language]
    categories = concepts_dict[country][lang2eng_mapping.get(language, language)]
    # Per-language category translations; untranslated names fall back to themselves.
    translations = words_mapping[language] if language in words_mapping else {}

    categories_keys_translated = [translations.get(cat, cat) for cat in categories]

    # The additional-concept dropdowns show the first five categories alphabetically.
    num_concept_dropdowns = 5
    categories_list = sorted(categories)[:num_concept_dropdowns]
    concept_dropdowns = [
        gr.update(label=translations.get(cat, cat), choices=sorted(categories[cat]))
        for cat in categories_list
    ]
    # With fewer than five categories, clear the spare dropdowns (this used to raise
    # IndexError) so the number of outputs still matches.
    concept_dropdowns += [
        gr.update(label="", choices=[])
        for _ in range(num_concept_dropdowns - len(concept_dropdowns))
    ]

    task_text = _read_localized_text(metadata_dict, country, language, "Task")
    inst_text = _read_localized_text(metadata_dict, country, language, "Instructions")

    return (
        gr.update(label=labels["Country"], value=country),
        gr.update(label=labels["Language"], value=language),
        gr.update(label=labels["Email"], value=email),
        gr.update(label=labels["Password"], value=password),
        gr.update(choices=categories_keys_translated, interactive=True, label=labels["Category"], allow_custom_value=False, elem_id="category_btn"),
        gr.update(choices=[], interactive=True, label=labels["Concept"], allow_custom_value=True, elem_id="concept_btn"),
        gr.update(label=labels["Image"]),
        gr.update(label=labels["Image_URL"]),
        gr.update(label=labels["Description"]),
        gr.Markdown(task_text),
        gr.Markdown(inst_text),
        gr.update(value=labels["Instructs_btn"]),
        gr.update(value=labels["Clear_btn"]),
        gr.update(value=labels["Submit_btn"]),
        gr.Markdown(labels["Saving_text"]),
        gr.Markdown(labels["Saved_text"]),
        gr.update(label=labels["Timestamp"]),
        gr.update(value=labels["Exit_btn"]),
        gr.Markdown(labels["Browse_text"]),
        gr.Markdown(labels["Loading_msg"]),
        gr.update(value=labels.get("Hide_all_btn", "👤 Hide All Faces")),
        gr.update(value=labels.get("Hide_btn", "👤 Hide Specific Faces")),
        gr.update(value=labels.get("Unhide_btn", "👀 Unhide Faces")),
        gr.update(value=labels.get("Exclude_btn", "Exclude Selected Example")),
        *concept_dropdowns,
    )
def update_intro_language(selected_country, selected_language, intro_markdown, metadata):
    """Swap the intro text for the selected country/language.

    Returns the current ``intro_markdown`` unchanged when no language is selected
    or when the localized intro file does not exist; otherwise a ``gr.Markdown``
    built from the file ``metadata[country][language]["Intro"]``.
    """
    if selected_language is not None:
        intro_path = metadata[selected_country][selected_language]["Intro"]
        if os.path.exists(intro_path):
            with open(intro_path, "r", encoding="utf-8") as handle:
                intro_text = handle.read()
            return gr.Markdown(intro_text)
    return intro_markdown
def handle_click_example(user_examples, concepts_dict):
    """Unpack a clicked gallery sample into the form's input values.

    Args:
        user_examples: One sample row, presumably as built by ``update_user_data``:
            ``[image, image_url, caption, country, language, category, concept,
            additional_concepts_by_category, id]``.
        concepts_dict: Currently unused; kept for interface compatibility.

    Returns:
        list: ``[image, image_url, caption, example_id, category, concept,
        *additional_concept_lists, True]``. Each additional-concept entry that is
        None, empty, or the placeholder ``[""]`` becomes ``[]``.
    """
    print("handle_click_example")
    print(user_examples)
    ex = list(user_examples)
    image_inp = ex[0]
    image_url_inp = ex[1]
    long_caption_inp = ex[2]
    category_btn = ex[5]
    concept_btn = ex[6]
    additional_concepts_by_category = ex[7]
    exampleid_btn = ex[8]

    def _normalize(concepts):
        # Stored rows use [""] for "no concepts"; a missing field can surface as None.
        if not concepts or (len(concepts) == 1 and concepts[0] == ""):
            return []
        return concepts

    # TODO: fix additional concepts not saving if categories in other language than English
    additional_concepts = [_normalize(c) for c in additional_concepts_by_category]
    # One value per category dropdown, then True (flag output).
    return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts + [True]
def is_password_correct(hashed_password, entered_password):
    """Return True if ``entered_password`` matches the bcrypt ``hashed_password``.

    Both arguments are str and are UTF-8 encoded before checking. A missing
    value or an unparseable stored hash counts as a non-match instead of raising,
    so a single bad record cannot abort filtering a whole dataset.
    """
    if not hashed_password or entered_password is None:
        return False
    try:
        return bcrypt.checkpw(entered_password.encode(), hashed_password.encode())
    except ValueError:
        # bcrypt raises ValueError ("Invalid salt") for hashes it cannot parse.
        return False
| ## Face blurring functions | |
def detect_faces(image, threshold=0.8):
    """Detect faces in an image using RetinaFace.

    Args:
        image (numpy.ndarray): Input image. The original docstring says BGR;
            NOTE(review): Gradio numpy images are usually RGB — confirm.
        threshold (float): Minimum detection confidence passed to RetinaFace.

    Returns:
        tuple: ``(faces, detection_time)``, where ``faces`` is RetinaFace's raw
        result (callers in this module treat it as a dict of detections with a
        "facial_area" box) and ``detection_time`` is elapsed seconds.
    """
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    detection_start = time.perf_counter()
    faces = RetinaFace.detect_faces(image, threshold=threshold)
    detection_time = time.perf_counter() - detection_start
    return faces, detection_time
| # Hide Faces Button | |
def select_faces_to_hide(image, blur_faces_ids):
    """Detect faces and prepare the face-selection modal with numbered boxes.

    Args:
        image: Current image array (presumably HxWxC numpy), or None.
        blur_faces_ids: Ignored on input; the returned choices replace it.

    Returns:
        A 7-tuple ``(image, modal_1, modal_2, annotated_image, face_count_text,
        faces_state, face_id_choices)``:
          * no image: both modals hidden, the rest cleared;
          * no faces: only ``modal_2`` shown (presumably a "no faces" notice);
          * faces found: ``modal_1`` shown with a copy annotated with red boxes and
            "ID: n" labels, the face count as a string, the detections wrapped in
            ``gr.State``, and choices "Face ID: 1" .. "Face ID: n".
    """
    if image is None:
        return None, Modal(visible=False), Modal(visible=False), None, "", None, gr.update(value=[])

    annotated = image.copy()
    faces, detection_time = detect_faces(annotated)
    print(f"Detection time: {detection_time:.2f} seconds")
    if not isinstance(faces, dict):
        # Only a dict of detections can be enumerated; treat anything else as "no faces".
        faces = {}

    for face_id, face_data in enumerate(faces.values(), start=1):
        # Plain ints: cv2 drawing calls can reject numpy integer coordinates.
        x1, y1, x2, y2 = (int(v) for v in face_data["facial_area"])
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Keep the label on-canvas for faces touching the top edge.
        cv2.putText(annotated, f"ID: {face_id}", (x1, max(y1 - 10, 12)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

    face_count = len(faces)
    if face_count == 0:
        return image, Modal(visible=False), Modal(visible=True), None, "", None, gr.update(value=[])

    face_id_choices = gr.update(choices=[f"Face ID: {i}" for i in range(1, face_count + 1)])
    return image, Modal(visible=True), Modal(visible=False), annotated, str(face_count), gr.State(faces), face_id_choices
def blur_selected_faces(image, blur_faces_ids, faces_info, face_img, faces_count):  # is_blurred
    """Blur the faces the user selected in the face-selection modal.

    Args:
        image: Image array to blur (sliced as ``image[y1:y2, x1:x2]``).
        blur_faces_ids: Selected labels of the form "Face ID: <n>".
        faces_info: State wrapper whose ``.value`` is the detections dict, keyed
            "face_<n>", each with a "facial_area" box ``(x1, y1, x2, y2)``.
        face_img: Annotated preview currently shown in the modal.
        faces_count: Face-count text currently shown in the modal.

    Returns:
        ``(image, modal, face_img, faces_count, blur_faces_ids)``. On success: a
        blurred copy, a hidden modal, and cleared preview/count/selection.
        Otherwise the unchanged image with the modal kept open.
    """
    if not blur_faces_ids:
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids  # is_blurred

    faces = faces_info.value
    if not (faces and isinstance(faces, dict)):
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids

    # "Face ID: 3" -> "face_3"
    face_keys = [f"face_{val.split(':')[-1].strip()}" for val in blur_faces_ids]
    # Validate every selection before touching pixels, so an unknown ID can no
    # longer leave the image half-blurred.
    missing = [key for key in face_keys if key not in faces]
    if missing:
        gr.Warning(f"⚠️ Face ID {missing[0].split('_')[-1]} not found in detected faces.", duration=5)
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids  # is_blurred

    # Kernel size bounds; GaussianBlur requires odd sizes.
    MIN_BLUR = 31
    MAX_BLUR = 131

    blurring_start = time.time()
    output_image = image.copy()
    ih, iw = output_image.shape[:2]
    for face_key in dict.fromkeys(face_keys):  # de-duplicate, keep order
        x1, y1, x2, y2 = (int(v) for v in faces[face_key]["facial_area"])
        # Kernel grows with face size relative to image width, kept odd and in bounds.
        face_size = max(x2 - x1, y2 - y1)
        blur_amount = int(MIN_BLUR + (MAX_BLUR - MIN_BLUR) * (face_size / iw))
        blur_amount = blur_amount if blur_amount % 2 == 1 else blur_amount + 1
        blur_amount = max(MIN_BLUR, min(MAX_BLUR, blur_amount))
        # Clip the box to the image.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(iw, x2), min(ih, y2)
        if x2 <= x1 or y2 <= y1:
            continue  # nothing inside the image; GaussianBlur rejects empty input
        face_region = output_image[y1:y2, x1:x2]
        output_image[y1:y2, x1:x2] = cv2.GaussianBlur(face_region, (blur_amount, blur_amount), 0)
    blurring_time = time.time() - blurring_start

    print("Face blurring performance metrics:")
    print(f"Face blurring time: {blurring_time:.4f} seconds")
    return output_image, Modal(visible=False), None, None, gr.update(value=[])
def blur_all_faces(image):
    """Detect every face with RetinaFace and blur each one.

    The Gaussian kernel size scales with face size relative to the image width.

    Args:
        image: Image array, or None.

    Returns:
        ``(image, modal)``: with no image, ``(None, hidden modal)``. When no face
        is found, the original image and a visible modal (presumably a "no faces"
        notice). Otherwise a blurred copy and a hidden modal.
    """
    if image is None:
        return None, Modal(visible=False)

    # Kernel size bounds; GaussianBlur requires odd sizes.
    MIN_BLUR = 31
    MAX_BLUR = 131

    start_time = time.time()
    detection_start = time.time()
    faces = RetinaFace.detect_faces(image)
    detection_time = time.time() - detection_start

    output_image = image.copy()
    ih, iw = image.shape[:2]
    face_count = 0
    blurring_start = time.time()
    if faces and isinstance(faces, dict):
        for face_data in faces.values():
            face_count += 1
            x1, y1, x2, y2 = (int(v) for v in face_data["facial_area"])
            face_size = max(x2 - x1, y2 - y1)
            blur_amount = int(MIN_BLUR + (MAX_BLUR - MIN_BLUR) * (face_size / iw))
            blur_amount = blur_amount if blur_amount % 2 == 1 else blur_amount + 1
            blur_amount = max(MIN_BLUR, min(MAX_BLUR, blur_amount))
            # Clip the box to the image.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(iw, x2), min(ih, y2)
            if x2 <= x1 or y2 <= y1:
                continue  # nothing inside the image; GaussianBlur rejects empty input
            face_region = output_image[y1:y2, x1:x2]
            output_image[y1:y2, x1:x2] = cv2.GaussianBlur(face_region, (blur_amount, blur_amount), 0)
    blurring_time = time.time() - blurring_start
    total_time = time.time() - start_time

    print("Face blurring performance metrics:")
    print(f"Total faces detected: {face_count}")
    print(f"Face detection time: {detection_time:.4f} seconds")
    print(f"Face blurring time: {blurring_time:.4f} seconds")
    print(f"Total processing time: {total_time:.4f} seconds")
    print(f"Average time per face: {(total_time/max(1, face_count)):.4f} seconds")

    if face_count == 0:
        return image, Modal(visible=True)
    return output_image, Modal(visible=False)
def unhide_faces(img, ori_img):  # is_blurred
    """Restore the original (un-blurred) image.

    Args:
        img: Current image array, or None.
        ori_img: State wrapper whose ``.value`` holds the original image, or None
            when no original was stored.

    Returns:
        None when there is no current image. The current image when no original
        is available or the image already equals it. Otherwise the stored original.
    """
    if img is None:
        return None
    original = getattr(ori_img, "value", None)
    # Without a stored original, keep the current image instead of wiping it.
    if original is None or np.array_equal(img, original):
        return img  # is_blurred
    return original
def check_exclude_fn(image):
    """Show the exclude confirmation only when an image is loaded.

    Emits a warning toast when there is no image. Returns an update whose
    ``visible`` flag is True exactly when ``image`` is not None.
    """
    has_image = image is not None
    if not has_image:
        gr.Warning("⚠️ No image to exclude.")
    return gr.update(visible=has_image)
def has_user_json(username, country, language_choice, local_ds_directory_path):
    """Return True if any JSON record exists for this user/country/language.

    Searches ``<local_ds_directory_path>/logged_in_users/<country>/<language_choice>/<username>/``
    at any depth, including directly inside the user folder.
    """
    user_dir = os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username)
    # Escape the literal part so characters such as "[" or "*" in a username or
    # path are matched verbatim instead of acting as glob wildcards.
    pattern = os.path.join(glob.escape(user_dir), "**", "*.json")
    # Stop at the first match instead of listing every file.
    return next(glob.iglob(pattern, recursive=True), None) is not None