# Hugging Face Space: image-captioning annotation tool for the EGYCOCO dataset.
# (Replaced page-scrape residue: "Spaces:" header and "Build error" banners.)
# Standard library
import os
import random
import threading

# Third-party
import gradio as gr
from datasets import Dataset, Features, Value, concatenate_datasets, load_dataset
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token from the environment.
# A missing/empty token is tolerated: the app still starts, but pushes will fail.
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    print("HUGGINGFACE_TOKEN environment variable not set.")
else:
    login(token=hf_token)
dataset_name = "GeorgeIbrahim/EGYCOCO"  # Replace with your dataset name

# Shared schema for both splits: one caption string per image id.
# Defined BEFORE the try/except — previously it was only created in the
# except branch, so the split-existence check below raised NameError
# whenever the dataset loaded successfully but was missing a split.
features = Features({
    'image_id': Value(dtype='string'),
    'caption': Value(dtype='string'),
})


def _empty_split():
    """Return an empty annotation split using the shared schema."""
    return Dataset.from_dict({'image_id': [], 'caption': []}, features=features)


# Load the dataset; if it does not exist yet, create empty splits and push
# them so the repository exists on the Hub.
try:
    dataset = load_dataset(dataset_name)
    print("Loaded existing dataset:", dataset)
except Exception:
    dataset = {"train": _empty_split(), "val": _empty_split()}
    dataset["train"].push_to_hub(dataset_name, split="train")
    dataset["val"].push_to_hub(dataset_name, split="val")

# Guarantee both splits exist even when the loaded dataset lacks one.
for _split in ("train", "val"):
    if _split not in dataset:
        dataset[_split] = _empty_split()
# Directory holding the images awaiting annotation.
image_folder = "test"

# All candidate image files; only common raster extensions are accepted.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
image_files = [
    name for name in os.listdir(image_folder)
    if name.endswith(_IMAGE_EXTENSIONS)
]

# Serializes dataset reads/updates across concurrent Gradio sessions.
lock = threading.Lock()
| # Function to get the appropriate split from the image ID | |
def get_split_from_image_id(image_id):
    """Map an image id to its dataset split.

    The split is inferred from the substring 'train' or 'val' in the id;
    'train' takes precedence when both appear.

    Raises:
        ValueError: if the id contains neither substring.
    """
    for split in ("train", "val"):
        if split in image_id:
            return split
    raise ValueError("Image ID does not contain a valid split identifier (train/val).")
| # Function to get a random image that hasn’t been annotated or skipped | |
def get_next_image(session_data):
    """Return the path of this session's current image, assigning one if needed.

    An image is eligible if its id is not yet present in either split
    (annotated or skipped). Returns None when no images remain.
    """
    with lock:
        # IDs already recorded (captioned or skipped) across both splits.
        done = set(dataset["train"]["image_id"]) | set(dataset["val"]["image_id"])
        remaining = [name for name in image_files if name not in done]
        if session_data["current_image"] is None and remaining:
            # This session has no image yet — hand it a random eligible one.
            session_data["current_image"] = random.choice(remaining)
    current = session_data["current_image"]
    return os.path.join(image_folder, current) if current else None
| # Function to save the annotation to the correct split and fetch the next image | |
def save_annotation(caption, session_data):
    """Record *caption* for the session's current image and serve the next one.

    Typing 'skip' (any casing/whitespace) stores the sentinel caption
    "skipped". The updated split is pushed to the Hub after every
    annotation.

    Args:
        caption: Text entered by the annotator.
        session_data: Per-session dict holding 'current_image'.

    Returns:
        (image_update, caption_update): Gradio updates for the image
        widget and the caption textbox.
    """
    # `global` hoisted to the top of the function (it previously appeared
    # mid-body, which is fragile: any earlier use of `dataset` would have
    # been a SyntaxError).
    global dataset

    if session_data["current_image"] is None:
        return gr.update(visible=False), gr.update(value="All images have been annotated!")

    with lock:
        image_id = session_data["current_image"]
        split = get_split_from_image_id(image_id)

        # Normalize the skip sentinel.
        if caption.strip().lower() == "skip":
            caption = "skipped"

        # Append the annotation as a new row to the appropriate split.
        new_row = Dataset.from_dict({"image_id": [image_id], "caption": [caption]})
        dataset[split] = concatenate_datasets([dataset[split], new_row])

        # Persist the updated split to the Hub.
        dataset[split].push_to_hub(dataset_name, split=split)
        print(f"Pushed updated dataset for split: {split}")

        # Release this session's image so it is handed a fresh one.
        session_data["current_image"] = None

    # Must run OUTSIDE the `with lock` block: get_next_image re-acquires the
    # same non-reentrant threading.Lock and would deadlock otherwise.
    next_image = get_next_image(session_data)
    if next_image:
        return gr.update(value=next_image), gr.update(value="")
    return gr.update(visible=False), gr.update(value="All images have been annotated!")
| # Function to skip the current image | |
| def skip_image(session_data): | |
| return save_annotation("skip", session_data) | |
| # Function to initialize the interface | |
| def initialize_interface(session_data): | |
| next_image = get_next_image(session_data) | |
| if next_image: | |
| return gr.update(value=next_image), gr.update(value="") | |
| else: | |
| return gr.update(visible=False), gr.update(value="All images have been annotated!") | |
| # Build the Gradio interface | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Image Captioning Tool") | |
| gr.Markdown("Please provide a caption for each image displayed. Click 'Submit' after writing your caption, or type 'skip' if you don’t want to annotate this image.") | |
| session_data = gr.State({"current_image": None}) # Session-specific state | |
| with gr.Row(): | |
| image = gr.Image() | |
| caption = gr.Textbox(placeholder="Enter caption here...") | |
| submit = gr.Button("Submit") | |
| skip = gr.Button("Skip") # Skip button | |
| # Define actions for buttons | |
| submit.click(fn=save_annotation, inputs=[caption, session_data], outputs=[image, caption]) | |
| skip.click(fn=skip_image, inputs=session_data, outputs=[image, caption]) | |
| # Load initial image | |
| demo.load(fn=initialize_interface, inputs=session_data, outputs=[image, caption]) | |
| demo.launch(share=True) | |