GeorgeSherif committed on
Commit
508e6cb
·
1 Parent(s): 5b3d599
Files changed (1) hide show
  1. app.py +62 -62
app.py CHANGED
@@ -4,8 +4,6 @@ import threading
4
  import random
5
  from datasets import load_dataset, Dataset, Features, Value, concatenate_datasets
6
  from huggingface_hub import login
7
- import json
8
- import re
9
 
10
  # Authenticate with Hugging Face
11
  token = os.getenv("HUGGINGFACE_TOKEN")
@@ -15,68 +13,60 @@ else:
15
  print("HUGGINGFACE_TOKEN environment variable not set.")
16
  dataset_name = "GeorgeIbrahim/EGYCOCO" # Replace with your dataset name
17
 
18
- # Load or create the dataset
19
  try:
20
  dataset = load_dataset(dataset_name, split="train")
21
  print("Loaded existing dataset:", dataset)
22
  except Exception as e:
23
- # Create an empty dataset if it doesn't exist
24
- features = Features({
25
- 'image_id': Value(dtype='string'),
26
- 'caption': Value(dtype='string'),
27
- })
28
- dataset = Dataset.from_dict({'image_id': [], 'caption': []}, features=features)
29
- dataset.push_to_hub(dataset_name) # Push the empty dataset to Hugging Face
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  image_folder = "images"
32
  image_files = [f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
33
  lock = threading.Lock()
34
 
35
- with open('nearest_neighbors_with_captions.json', 'r') as f:
36
- results = json.load(f)
37
- # print("Loaded JSON data:", results) # Add this line to verify the data
38
-
39
- def get_caption_for_image_id(image_path):
40
- """
41
- Retrieve the caption for a given image_id from the JSON data.
42
- """
43
- # Extract the numeric part of the image ID
44
- match = re.search(r'_(\d+)\.', image_path)
45
- if match:
46
- image_id = match.group(1).lstrip('0') # Remove leading zeros
47
- print("Searching for image_id:", image_id) # Debugging line
48
-
49
- # Check if image_id is a test image
50
- if image_id in results:
51
- print("Found caption in results:", results[image_id]["caption"]) # Debugging line
52
- return results[image_id]["caption"]
53
-
54
- # If image_id is not a test image, search in nearest neighbors
55
- for test_image_data in results.values():
56
- for neighbor in test_image_data["nearest_neighbors"]:
57
- if neighbor["image_id"] == image_id:
58
- print("Found caption in nearest neighbors:", neighbor["caption"]) # Debugging line
59
- return neighbor["caption"]
60
-
61
- # Return None if the image_id is not found
62
- print("Caption not found for image_id:", image_id) # Debugging line
63
- return None
64
-
65
- # Function to get a random image that hasn’t been annotated or skipped
66
  def get_next_image(session_data):
67
  with lock:
68
- annotated_images = set(dataset["image_id"]) # Set of annotated images
69
- available_images = [img for img in image_files if img not in annotated_images]
 
 
 
 
 
 
 
 
 
70
  # Check if the user already has an image
71
  if session_data["current_image"] is None and available_images:
72
  # Assign a new random image to the user
73
  session_data["current_image"] = random.choice(available_images)
 
74
  return os.path.join(image_folder, session_data["current_image"]) if session_data["current_image"] else None
75
 
76
- # Function to save the annotation to Hugging Face dataset and fetch the next image
77
  def save_annotation(caption, session_data):
78
  if session_data["current_image"] is None:
79
- return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="")
80
 
81
  with lock:
82
  image_id = session_data["current_image"]
@@ -85,8 +75,20 @@ def save_annotation(caption, session_data):
85
  if caption.strip().lower() == "skip":
86
  caption = "skipped"
87
 
 
 
 
 
 
 
 
88
  # Add the new annotation as a new row to the dataset
89
- new_data = Dataset.from_dict({"image_id": [image_id], "caption": [caption]})
 
 
 
 
 
90
  global dataset
91
  dataset = concatenate_datasets([dataset, new_data])
92
 
@@ -94,49 +96,47 @@ def save_annotation(caption, session_data):
94
  dataset.push_to_hub(dataset_name)
95
  print("Pushed updated dataset")
96
 
97
- # Clear user's current image so they get a new one next time
98
- session_data["current_image"] = None
 
99
 
100
  # Fetch the next image
101
  next_image = get_next_image(session_data)
102
  if next_image:
103
- next_caption = get_caption_for_image_id(os.path.basename(next_image)) # Retrieve the caption for the new image
104
- return gr.update(value=next_image), gr.update(value=""), gr.update(value=next_caption or "")
105
  else:
106
- return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="")
107
 
108
  # Function to skip the current image
109
- # def skip_image(session_data):
110
- # return save_annotation("skip", session_data)
111
 
 
112
  def initialize_interface(session_data):
113
  next_image = get_next_image(session_data)
114
  if next_image:
115
- next_caption = get_caption_for_image_id(os.path.basename(next_image)) # Retrieve caption for initial image
116
- print(next_caption)
117
- return gr.update(value=next_image), gr.update(value=next_caption or "")
118
  else:
119
  return gr.update(visible=False), gr.update(value="All images have been annotated!")
120
 
121
  # Build the Gradio interface
122
  with gr.Blocks() as demo:
123
  gr.Markdown("# Image Captioning Tool")
124
- gr.Markdown("Please provide your caption in Egyptian Arabic 'Masri'")
125
 
126
  session_data = gr.State({"current_image": None}) # Session-specific state
127
 
128
  with gr.Row():
129
  image = gr.Image()
130
  caption = gr.Textbox(placeholder="Enter caption here...")
131
- existing_caption = gr.Textbox(label="Existing Caption", interactive=False) # Display existing caption
132
  submit = gr.Button("Submit")
133
- # skip = gr.Button("Skip") # Skip button
134
 
135
  # Define actions for buttons
136
- submit.click(fn=save_annotation, inputs=[caption, session_data], outputs=[image, caption, existing_caption])
137
- # skip.click(fn=skip_image, inputs=session_data, outputs=[image, caption])
138
 
139
  # Load initial image
140
- demo.load(fn=initialize_interface, inputs=session_data, outputs=[image, existing_caption])
141
 
142
  demo.launch(share=True)
 
4
  import random
5
  from datasets import load_dataset, Dataset, Features, Value, concatenate_datasets
6
  from huggingface_hub import login
 
 
7
 
8
  # Authenticate with Hugging Face
9
  token = os.getenv("HUGGINGFACE_TOKEN")
 
13
  print("HUGGINGFACE_TOKEN environment variable not set.")
14
  dataset_name = "GeorgeIbrahim/EGYCOCO" # Replace with your dataset name
15
 
16
+ # Load the existing dataset or create it if not available
17
  try:
18
  dataset = load_dataset(dataset_name, split="train")
19
  print("Loaded existing dataset:", dataset)
20
  except Exception as e:
21
+ print("Failed to load dataset:", e)
22
+ dataset = None
23
+
24
+ # Check if "annotation_count" exists, if not, add it
25
+ if dataset is not None:
26
+ if "annotation_count" not in dataset.column_names:
27
+ # Define the updated features with annotation_count added
28
+ features = dataset.features.copy()
29
+ features["annotation_count"] = Value(dtype="int32")
30
+
31
+ # Update dataset with new feature, initializing annotation_count based on existing annotations
32
+ dataset = dataset.map(
33
+ lambda row: {"annotation_count": 1 if "val" in row["image_id"] else 0},
34
+ features=features
35
+ )
36
+
37
+ # Push the updated dataset with the new feature to Hugging Face Hub
38
+ dataset.push_to_hub(dataset_name)
39
+ print("Updated dataset with annotation_count and pushed to Hub")
40
 
41
  image_folder = "images"
42
  image_files = [f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
43
  lock = threading.Lock()
44
 
45
+ # Function to get a random image that hasn’t been fully annotated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def get_next_image(session_data):
47
  with lock:
48
+ # Retrieve set of annotated images with counts
49
+ annotated_images = {item["image_id"]: item["annotation_count"] for item in dataset}
50
+
51
+ # Available images filter
52
+ available_images = [
53
+ img for img in image_files
54
+ if img not in annotated_images or
55
+ ("val" in img and annotated_images[img] < 2) or
56
+ ("val" not in img and annotated_images[img] == 0)
57
+ ]
58
+
59
  # Check if the user already has an image
60
  if session_data["current_image"] is None and available_images:
61
  # Assign a new random image to the user
62
  session_data["current_image"] = random.choice(available_images)
63
+
64
  return os.path.join(image_folder, session_data["current_image"]) if session_data["current_image"] else None
65
 
66
+ # Function to save the annotation to the Hugging Face dataset and fetch the next image
67
  def save_annotation(caption, session_data):
68
  if session_data["current_image"] is None:
69
+ return gr.update(visible=False), gr.update(value="All images have been annotated!")
70
 
71
  with lock:
72
  image_id = session_data["current_image"]
 
75
  if caption.strip().lower() == "skip":
76
  caption = "skipped"
77
 
78
+ # Check if image is already in dataset to update count
79
+ existing_image = dataset.filter(lambda x: x["image_id"] == image_id)
80
+ if len(existing_image):
81
+ annotation_count = existing_image[0]["annotation_count"]
82
+ else:
83
+ annotation_count = 0
84
+
85
  # Add the new annotation as a new row to the dataset
86
+ new_data = Dataset.from_dict({
87
+ "image_id": [image_id],
88
+ "caption": [caption],
89
+ "annotation_count": [annotation_count + 1]
90
+ })
91
+
92
  global dataset
93
  dataset = concatenate_datasets([dataset, new_data])
94
 
 
96
  dataset.push_to_hub(dataset_name)
97
  print("Pushed updated dataset")
98
 
99
+ # Clear user's current image if the validation image has been annotated five times
100
+ if ("val" not in image_id) or (annotation_count + 1 >= 5):
101
+ session_data["current_image"] = None
102
 
103
  # Fetch the next image
104
  next_image = get_next_image(session_data)
105
  if next_image:
106
+ return gr.update(value=next_image), gr.update(value="")
 
107
  else:
108
+ return gr.update(visible=False), gr.update(value="All images have been annotated!")
109
 
110
  # Function to skip the current image
111
+ def skip_image(session_data):
112
+ return save_annotation("skip", session_data)
113
 
114
+ # Function to initialize the interface
115
  def initialize_interface(session_data):
116
  next_image = get_next_image(session_data)
117
  if next_image:
118
+ return gr.update(value=next_image), gr.update(value="")
 
 
119
  else:
120
  return gr.update(visible=False), gr.update(value="All images have been annotated!")
121
 
122
  # Build the Gradio interface
123
  with gr.Blocks() as demo:
124
  gr.Markdown("# Image Captioning Tool")
125
+ gr.Markdown("Please provide a caption for each image displayed. Click 'Submit' after writing your caption, or type 'skip' if you don’t want to annotate this image.")
126
 
127
  session_data = gr.State({"current_image": None}) # Session-specific state
128
 
129
  with gr.Row():
130
  image = gr.Image()
131
  caption = gr.Textbox(placeholder="Enter caption here...")
 
132
  submit = gr.Button("Submit")
133
+ skip = gr.Button("Skip") # Skip button
134
 
135
  # Define actions for buttons
136
+ submit.click(fn=save_annotation, inputs=[caption, session_data], outputs=[image, caption])
137
+ skip.click(fn=skip_image, inputs=session_data, outputs=[image, caption])
138
 
139
  # Load initial image
140
+ demo.load(fn=initialize_interface, inputs=session_data, outputs=[image, caption])
141
 
142
  demo.launch(share=True)