# Hugging Face Space app (runs on T4 GPU hardware).
| import os | |
| import json | |
| import time | |
| from huggingface_hub import HfApi, create_repo, CommitScheduler | |
| import bcrypt | |
| import shutil | |
| import uuid | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
def load_concepts(path="data/concepts.json"):
    """Load the concepts JSON file with countries and languages sorted.

    Args:
        path: Path to the concepts JSON file.

    Returns:
        dict: ``{country: {language: concepts}}`` with keys alphabetically
        sorted at both levels (dicts preserve insertion order).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Rebuild the mapping with sorted keys; insertion order is preserved.
    return {
        country: {lang: langs[lang] for lang in sorted(langs)}
        for country, langs in sorted(data.items())
    }
def load_metadata(path="data/metadata.json"):
    """Load the metadata JSON file with countries and languages sorted.

    Args:
        path: Path to the metadata JSON file.

    Returns:
        dict: ``{country: {language: metadata}}`` with keys alphabetically
        sorted at both levels (dicts preserve insertion order).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Rebuild the mapping with sorted keys; insertion order is preserved.
    return {
        country: {lang: langs[lang] for lang in sorted(langs)}
        for country, langs in sorted(data.items())
    }
class CustomHFDatasetSaver:
    """Stages crowd-sourced image/caption samples in a local folder and
    commits them to a Hugging Face dataset repository via a background
    ``CommitScheduler``.
    """

    def __init__(self, api_token, dataset_name, private=False):
        # Hub credentials and target repo; the repo itself is created
        # lazily in setup().
        self.api_token = api_token
        self.dataset_name = dataset_name
        self.private = private
        self.api = HfApi()

    def setup(self, data_outputs, local_ds_folder):
        """Create the dataset repo (if missing), the local staging folder,
        and the scheduler that pushes the folder to the Hub every minute.

        Args:
            data_outputs: Ordered field names matching the positional
                values later passed to ``save``.
            local_ds_folder: Local directory used to stage samples
                before each commit.
        """
        # Create the repo if it does not exist; keep the fully-qualified id.
        self.dataset_name = create_repo(
            repo_id=self.dataset_name,
            token=self.api_token,
            private=self.private,
            repo_type="dataset",
            exist_ok=True,
        ).repo_id
        # Create the local data folder if it does not exist.
        self.local_ds_folder = local_ds_folder
        os.makedirs(self.local_ds_folder, exist_ok=True)
        self.data_outputs = data_outputs  # field names to zip with save() values
        # Commit the staged data to the Hub every minute.
        self.scheduler = CommitScheduler(
            repo_id=self.dataset_name,
            repo_type="dataset",
            folder_path=self.local_ds_folder,
            every=1,
            token=self.api_token,
        )

    def validate_data(self, values_dic):
        """
        Validates the data before saving to ensure no required fields are empty.

        Args:
            values_dic: Mapping of field name -> submitted value.

        Returns:
            (bool, str): ``(True, "")`` when validation passes, otherwise
            ``(False, error_message)``.
        """
        # 'image' is not in required_fields since it is handled separately.
        required_fields = ['country', 'language', 'category', 'concept', 'caption']
        image = values_dic.get('image')
        image_url = values_dic.get('image_url')
        # An image counts if it is a gradio upload payload (dict) or a
        # non-empty array-like object exposing .shape.
        has_image = image is not None and (
            isinstance(image, dict) or (hasattr(image, 'shape') and image.shape[0] > 0)
        )
        has_url = image_url is not None and image_url.strip() != ""
        if not has_image and not has_url:
            return False, "Either an image or image URL must be provided"
        for field in required_fields:
            value = values_dic.get(field)
            if value is None or (isinstance(value, str) and value.strip() == ""):
                return False, f"Required field '{field}' cannot be empty"
        # An uploaded file must still exist on disk.
        if has_image and isinstance(image, dict):
            if not os.path.exists(image.get('path', '')):
                return False, "Image file not found"
        return True, ""

    # TODO: add a function to check if the user is logged in
    def is_logged_in(self):
        pass

    # TODO: check if the user is logged in (add a decorator to the save function)
    def save(self, *values):
        """Validate one submission, stage its image and metadata JSON locally,
        and let the scheduler commit them to the Hub.

        Args:
            *values: Component values in the same order as ``self.data_outputs``.

        Raises:
            gr.Error: If validation fails.
        """
        # Map the positional component values to their field names.
        values_dic = dict(zip(self.data_outputs, values))
        is_valid, error_msg = self.validate_data(values_dic)
        if not is_valid:
            raise gr.Error(error_msg)
        # Never store the raw password. Guard against a missing/None password
        # (anonymous submissions) which would crash bcrypt's .encode().
        values_dic['password'] = self.hash_password(values_dic.get('password') or "")
        # TODO: fix saving additional concepts when the UI is not displayed in English.
        current_timestamp = int(time.time() * 1000)
        # Create a unique ID for the sample if one is not provided.
        if not values_dic.get("id"):
            values_dic["id"] = "_".join(
                str(values_dic.get(k)) for k in ("country", "language", "category", "concept")
            ) + f"_{current_timestamp}"
        # Prepare the sample directory and stage the image inside it.
        sample_dir = self._sample_dir(values_dic, current_timestamp)
        os.makedirs(os.path.join(self.local_ds_folder, sample_dir), exist_ok=True)
        dest_image_path = os.path.join(sample_dir, "image.png")
        self._stage_image(values_dic.get('image'),
                          os.path.join(self.local_ds_folder, dest_image_path))
        values_dic['image'] = dest_image_path
        data_dict = self._build_record(values_dic, dest_image_path, current_timestamp)
        print(f"Data dictionary: {data_dict}")
        # Write the metadata JSON next to the image; hold the scheduler lock so
        # a commit never snapshots a half-written file.
        json_file_path = os.path.join(
            self.local_ds_folder, sample_dir, f"sample_{current_timestamp}.json"
        )
        with self.scheduler.lock:
            with open(json_file_path, "w", encoding="utf-8") as f:
                json.dump(data_dict, f, indent=2)
        print("Data saved successfully")

    def _sample_dir(self, values_dic, current_timestamp):
        """Return the relative directory for this sample: logged-in users are
        grouped by username, anonymous submissions get a random UUID."""
        if values_dic.get("username"):
            return os.path.join("logged_in_users", values_dic["country"],
                                values_dic["language"], values_dic["username"],
                                str(current_timestamp))
        return os.path.join("anonymous_users", values_dic["country"],
                            values_dic["language"], str(uuid.uuid4()))

    def _stage_image(self, image_data, full_dest_path):
        """Write the submitted image to ``full_dest_path`` in whichever form it
        arrived. Holds the scheduler lock so a commit never snapshots a
        half-written file. URL-only submissions (``image_data is None``) write
        nothing."""
        if isinstance(image_data, dict) and 'path' in image_data:
            # Gradio upload payload — copy the temporary file.
            with self.scheduler.lock:
                shutil.copy(image_data['path'], full_dest_path)
        elif isinstance(image_data, np.ndarray):
            # Numpy array — convert to a PIL image and save.
            with self.scheduler.lock:
                Image.fromarray(image_data).save(full_dest_path)
        elif isinstance(image_data, Image.Image):
            with self.scheduler.lock:
                image_data.save(full_dest_path)
        elif isinstance(image_data, str) and os.path.isfile(image_data):
            # Plain file path (e.g. gr.Image(type="filepath")).
            with self.scheduler.lock:
                shutil.copy(image_data, full_dest_path)

    def _build_record(self, values_dic, dest_image_path, current_timestamp):
        """Assemble the JSON-serializable metadata record for one sample."""
        hub_url = (
            f"https://huggingface.co/datasets/{self.dataset_name}"
            f"/resolve/main/{dest_image_path}"
        )
        record = {
            # Normalize to forward slashes in case staging ran on Windows.
            "image": values_dic['image'].replace("\\", "/"),
            "image_file": hub_url.replace("\\", "/"),
            "image_url": values_dic.get('image_url') or "",
            "caption": values_dic['caption'] or "",
            "country": values_dic['country'] or "",
            "language": values_dic['language'] or "",
            "category": values_dic['category'] or "",
            "concept": values_dic['concept'] or "",
        }
        # Per-category concept selections; an empty or missing selection is [""].
        for i in range(1, 6):
            key = f"category_{i}_concepts"
            record[key] = values_dic.get(key) or [""]
        record.update({
            "timestamp": current_timestamp,
            "username": values_dic.get('username') or "",
            "password": values_dic['password'] or "",
            "id": values_dic['id'],
            "excluded": bool(values_dic.get('excluded')),
        })
        return record

    def hash_password(self, raw_password):
        """
        Hash a raw password using bcrypt.

        Args:
            raw_password (str): The plain text password to be hashed.

        Returns:
            str: The hashed password as a string.
        """
        return bcrypt.hashpw(raw_password.encode(), bcrypt.gensalt()).decode()