scmlewis committed on
Commit
9517d28
·
verified ·
1 Parent(s): 229a996

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -15
app.py CHANGED
@@ -6,14 +6,14 @@ from PIL import Image
6
  from collections import deque
7
  import numpy as np
8
 
9
- # Load main BLIP model for English captioning
10
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
11
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
12
 
13
- # Load YOLOv5 small model for object detection using ultralytics package
14
  detect_model = YOLO('yolov5s.pt')
15
 
16
- # Setup MarianMT translation models cache for multilingual captions
17
  translation_models = {
18
  "English": None,
19
  "French": ("Helsinki-NLP/opus-mt-en-fr", "Helsinki-NLP/opus-mt-fr-en"),
@@ -42,10 +42,10 @@ def translate_caption(caption, target_lang):
42
  translated = tokenizer.decode(gen[0], skip_special_tokens=True)
43
  return translated
44
 
45
- # Session memory for last 15 images and captions
46
  MEMORY_SIZE = 15
47
  last_images = deque([], maxlen=MEMORY_SIZE)
48
  last_captions = deque([], maxlen=MEMORY_SIZE)
 
49
 
50
  def preprocess_image(image):
51
  if image.mode != "RGB":
@@ -74,40 +74,97 @@ def generate_caption(image, language):
74
  # Update session memory
75
  last_images.append(image)
76
  last_captions.append(caption_translated)
 
77
 
78
  tags = ", ".join(detected_objs) if detected_objs else "None"
79
- gallery = [(img, cap) for img, cap in zip(list(last_images), list(last_captions))]
80
-
81
  result_text = f"Detected objects: {tags}\nCaption ({language}): {caption_translated}"
82
- return result_text, gallery
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  with gr.Blocks() as iface:
85
  gr.Markdown("# Image Captioning with Object Detection & Multilingual Support")
86
 
 
 
 
 
 
 
 
 
 
87
  language = gr.Dropdown(
88
  label="Select Caption Language",
89
  choices=["English", "French", "Spanish", "German"],
90
  value="English"
91
  )
92
 
93
- image_input = gr.Image(type="pil", label="Upload Image")
94
-
95
- caption_output = gr.Textbox(label="Caption and Detected Objects", lines=3, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Fixed: removed style() method, added scale parameter to Gallery
98
- gallery = gr.Gallery(label="Last 15 Images and Captions", scale=3)
99
 
100
- generate_btn = gr.Button("Generate Caption")
 
 
 
 
 
101
 
102
  def on_generate(image, language):
103
  if image is None:
104
  return "Please upload an image.", []
105
- return generate_caption(image, language)
 
 
106
 
107
  generate_btn.click(
108
  fn=on_generate,
109
  inputs=[image_input, language],
110
- outputs=[caption_output, gallery]
 
 
 
 
 
 
111
  )
112
 
113
  if __name__ == "__main__":
 
6
  from collections import deque
7
  import numpy as np
8
 
9
+ # Load BLIP model for English captioning
10
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
11
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
12
 
13
+ # Load YOLOv5 small model for detection
14
  detect_model = YOLO('yolov5s.pt')
15
 
16
+ # MarianMT translation models cache
17
  translation_models = {
18
  "English": None,
19
  "French": ("Helsinki-NLP/opus-mt-en-fr", "Helsinki-NLP/opus-mt-fr-en"),
 
42
  translated = tokenizer.decode(gen[0], skip_special_tokens=True)
43
  return translated
44
 
 
45
  MEMORY_SIZE = 15
46
  last_images = deque([], maxlen=MEMORY_SIZE)
47
  last_captions = deque([], maxlen=MEMORY_SIZE)
48
+ last_languages = deque([], maxlen=MEMORY_SIZE)
49
 
50
  def preprocess_image(image):
51
  if image.mode != "RGB":
 
74
  # Update session memory
75
  last_images.append(image)
76
  last_captions.append(caption_translated)
77
+ last_languages.append(language)
78
 
79
  tags = ", ".join(detected_objs) if detected_objs else "None"
 
 
80
  result_text = f"Detected objects: {tags}\nCaption ({language}): {caption_translated}"
81
+
82
+ # Prepare table data for last 15 images with copyable captions and copy buttons
83
+ history_rows = []
84
+ for img, cap, lang in zip(last_images, last_captions, last_languages):
85
+ history_rows.append([img, cap])
86
+
87
+ return result_text, history_rows
88
+
89
+ def gallery_to_table(history_rows):
90
+ # history_rows is list of [PIL image, caption text]
91
+ headers = ["Image", "Caption (click to copy)"]
92
+ data = []
93
+ for img, cap in history_rows:
94
+ data.append([
95
+ img,
96
+ gr.Textbox.update(value=cap, interactive=True)
97
+ ])
98
+ return headers, data
99
 
100
  with gr.Blocks() as iface:
101
  gr.Markdown("# Image Captioning with Object Detection & Multilingual Support")
102
 
103
+ gr.Markdown("""
104
+
105
+ This app generates descriptive captions for your uploaded images, detects objects within them,
106
+ and supports multilingual captions. Upload an image, then click 'Generate Caption' to see results.
107
+
108
+ Your last 15 images and captions are saved below for easy reference and copying.
109
+
110
+ """)
111
+
112
  language = gr.Dropdown(
113
  label="Select Caption Language",
114
  choices=["English", "French", "Spanish", "German"],
115
  value="English"
116
  )
117
 
118
+ with gr.Row():
119
+ with gr.Column(scale=2):
120
+ image_input = gr.Image(type="pil", label="Upload Image")
121
+ generate_btn = gr.Button("Generate Caption")
122
+ with gr.Column(scale=3):
123
+ caption_output = gr.Textbox(
124
+ label="Caption & Detected Objects",
125
+ lines=4,
126
+ interactive=True
127
+ )
128
+ copy_btn = gr.Button("Copy Caption Text")
129
+
130
+ # History table with thumbnails and copyable captions
131
+ history_table = gr.Dataframe(
132
+ headers=["Image", "Caption"],
133
+ row_count=(MEMORY_SIZE, MEMORY_SIZE),
134
+ col_count=2,
135
+ datatype=["image", "str"],
136
+ interactive=False,
137
+ wrap=True,
138
+ label="Last 15 Images and Captions"
139
+ )
140
 
141
+ def copy_text(caption_text):
142
+ return gr.update(value=caption_text)
143
 
144
+ def update_history(history_rows):
145
+ # Convert to format compatible with gr.Dataframe
146
+ data = []
147
+ for img, cap in history_rows:
148
+ data.append([img, cap])
149
+ return data
150
 
151
  def on_generate(image, language):
152
  if image is None:
153
  return "Please upload an image.", []
154
+ result_text, history_rows = generate_caption(image, language)
155
+ history_data = update_history(history_rows)
156
+ return result_text, history_data
157
 
158
  generate_btn.click(
159
  fn=on_generate,
160
  inputs=[image_input, language],
161
+ outputs=[caption_output, history_table]
162
+ )
163
+
164
+ copy_btn.click(
165
+ fn=lambda text: text,
166
+ inputs=[caption_output],
167
+ outputs=[caption_output]
168
  )
169
 
170
  if __name__ == "__main__":