scmlewis committed on
Commit
98c6b6d
·
verified ·
1 Parent(s): 3fdf4eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -74
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import BlipProcessor, BlipForConditionalGeneration, MarianMTModel, MarianTokenizer
2
  from ultralytics import YOLO
3
  import torch
4
  import gradio as gr
@@ -6,47 +6,16 @@ from PIL import Image
6
  from collections import deque
7
  import numpy as np
8
 
9
- # Load BLIP model for English captioning
10
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
11
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
12
 
13
- # Load YOLOv5 small model for detection
14
  detect_model = YOLO('yolov5s.pt')
15
 
16
- # Setup MarianMT translation models cache for multilingual captions
17
- translation_models = {
18
- "English": None,
19
- "French": ("Helsinki-NLP/opus-mt-en-fr", "Helsinki-NLP/opus-mt-fr-en"),
20
- "Spanish": ("Helsinki-NLP/opus-mt-en-es", "Helsinki-NLP/opus-mt-es-en"),
21
- "German": ("Helsinki-NLP/opus-mt-en-de", "Helsinki-NLP/opus-mt-de-en"),
22
- }
23
- translation_cache = {}
24
-
25
def get_translation_model(lang_code):
    """Return a ``(tokenizer, model)`` pair for *lang_code*, or ``None`` when
    no translation is needed (e.g. "English").

    Results are memoized in ``translation_cache`` so each MarianMT model is
    downloaded and loaded at most once per process.
    """
    if lang_code not in translation_cache:
        # "English" maps to None in translation_models; the original code did
        # `model_name, _ = translation_models[lang_code]`, which raises
        # TypeError when the entry is None (cannot unpack). Guard first.
        entry = translation_models.get(lang_code)
        if entry is not None:
            model_name, _ = entry
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)
            translation_cache[lang_code] = (tokenizer, model)
        else:
            # Unknown languages are also cached as None (no translation),
            # instead of raising KeyError.
            translation_cache[lang_code] = None
    return translation_cache[lang_code]
35
-
36
def translate_caption(caption, target_lang):
    """Translate an English *caption* into *target_lang* with MarianMT.

    Returns the caption unchanged for English or for languages with no
    configured translation model.

    Bug fix: the original tested ``translation_cache.get(target_lang)``,
    which is empty before the first call for a language — so the first
    request silently returned the untranslated English caption and the
    translation model was never loaded. Decide from the static
    ``translation_models`` table instead, then load lazily via
    ``get_translation_model``.
    """
    if target_lang == "English" or translation_models.get(target_lang) is None:
        return caption
    tokenizer, model = get_translation_model(target_lang)
    batch = tokenizer([caption], return_tensors="pt")
    gen = model.generate(**batch)
    translated = tokenizer.decode(gen[0], skip_special_tokens=True)
    return translated
44
-
45
  MEMORY_SIZE = 15
46
  last_images = deque([], maxlen=MEMORY_SIZE)
47
- last_captions = deque([], maxlen=MEMORY_SIZE)
48
- last_objects = deque([], maxlen=MEMORY_SIZE)
49
- last_languages = deque([], maxlen=MEMORY_SIZE)
50
 
51
  def preprocess_image(image):
52
  if image.mode != "RGB":
@@ -64,99 +33,84 @@ def detect_objects(image):
64
  detected_objs.add(label)
65
  return list(detected_objs)
66
 
67
- def generate_caption(image, language):
68
  image = preprocess_image(image)
69
  inputs = processor(image, return_tensors="pt")
70
  out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
71
- caption_en = processor.decode(out[0], skip_special_tokens=True)
72
- caption_translated = translate_caption(caption_en, language)
73
  detected_objs = detect_objects(image)
 
 
 
74
 
75
  # Update session memory
76
  last_images.append(image)
77
- last_captions.append(caption_translated)
78
- last_objects.append(detected_objs)
79
- last_languages.append(language)
80
-
81
- tags = ", ".join(detected_objs) if detected_objs else "None"
82
 
83
- return caption_translated, tags
84
 
85
  def build_history_ui():
86
- # Build list of Gradio Rows containing image, caption textbox and copy button
87
  rows = []
88
  for i in range(len(last_images)):
89
  img = last_images[i]
90
- cap = last_captions[i]
91
- obj = last_objects[i]
92
- lang = last_languages[i]
93
 
94
- cap_box = gr.Textbox(value=cap, lines=2, interactive=True, show_label=False)
95
-
96
- copy_btn = gr.Button("Copy Caption")
97
 
98
  def copy_fn(caption):
99
  return caption
100
 
101
- # Bind copy button inside lambda to close over correct caption_box
102
  copy_btn.click(fn=copy_fn, inputs=cap_box, outputs=cap_box)
103
 
104
  row = gr.Row([
105
  gr.Image(value=img, interactive=False, show_label=False, elem_id=f"history_img_{i}"),
106
  gr.Column([
107
- gr.Markdown(f"**Caption ({lang}):**"),
108
  cap_box,
109
  copy_btn,
110
- gr.Markdown(f"**Detected Objects:** {', '.join(obj) if obj else 'None'}")
111
  ])
112
  ])
113
  rows.append(row)
114
  return rows
115
 
116
  with gr.Blocks() as iface:
117
- gr.Markdown("# Image Captioning with Object Detection & Multilingual Support")
 
118
  gr.Markdown(
119
  """
120
- Upload an image, select the caption language, then click 'Generate Caption'.
121
- The app generates a caption and detected object tags.
122
- Your last 15 images and captions are displayed below for easy copying and reference.
123
  """
124
  )
125
 
126
- language = gr.Dropdown(
127
- label="Select Caption Language",
128
- choices=["English", "French", "Spanish", "German"],
129
- value="English"
130
- )
131
-
132
  with gr.Row():
133
  with gr.Column(scale=2):
134
  image_input = gr.Image(type="pil", label="Upload Image")
135
  generate_btn = gr.Button("Generate Caption")
136
  with gr.Column(scale=3):
137
- caption_output = gr.Textbox(label="Caption", lines=3, interactive=True)
138
- object_output = gr.Textbox(label="Detected Objects", lines=2, interactive=False)
139
- copy_btn = gr.Button("Copy Caption Text")
140
 
141
  history_container = gr.Column()
142
 
143
- def on_generate(image, language):
144
  if image is None:
145
- return "Please upload an image.", "", []
146
- caption, objects = generate_caption(image, language)
147
  history = build_history_ui()
148
- return caption, objects, history
149
 
150
  def copy_text(text):
151
  return gr.Textbox.update(value=text, interactive=True)
152
 
153
  generate_btn.click(
154
  fn=on_generate,
155
- inputs=[image_input, language],
156
- outputs=[caption_output, object_output, history_container]
157
  )
158
 
159
- copy_btn.click(fn=copy_text, inputs=caption_output, outputs=caption_output)
160
 
161
  if __name__ == "__main__":
162
  iface.launch()
 
1
+ from transformers import BlipProcessor, BlipForConditionalGeneration
2
  from ultralytics import YOLO
3
  import torch
4
  import gradio as gr
 
6
  from collections import deque
7
  import numpy as np
8
 
# Load BLIP model for image captioning
# NOTE: from_pretrained downloads the weights on first run; both objects are
# module-level singletons shared by every request of the app.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load YOLOv5 model for object detection
detect_model = YOLO('yolov5s.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Bounded session memory: once MEMORY_SIZE entries are stored, the oldest is
# evicted automatically (deque maxlen semantics). The two deques are always
# appended together, so they stay index-aligned.
# NOTE(review): this state is module-global, so in a multi-user Gradio
# deployment it is shared across all sessions — confirm that is intended.
MEMORY_SIZE = 15
last_images = deque([], maxlen=MEMORY_SIZE)
last_texts = deque([], maxlen=MEMORY_SIZE)  # will store combined caption + detected objects
 
 
19
 
20
  def preprocess_image(image):
21
  if image.mode != "RGB":
 
33
  detected_objs.add(label)
34
  return list(detected_objs)
35
 
36
+ def generate_caption_with_objects(image):
37
  image = preprocess_image(image)
38
  inputs = processor(image, return_tensors="pt")
39
  out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
40
+ caption = processor.decode(out[0], skip_special_tokens=True)
 
41
  detected_objs = detect_objects(image)
42
+ tags = ", ".join(detected_objs) if detected_objs else "None"
43
+
44
+ combined_text = f"Detected objects: {tags}\nCaption: {caption}"
45
 
46
  # Update session memory
47
  last_images.append(image)
48
+ last_texts.append(combined_text)
 
 
 
 
49
 
50
+ return combined_text
51
 
52
  def build_history_ui():
 
53
  rows = []
54
  for i in range(len(last_images)):
55
  img = last_images[i]
56
+ text = last_texts[i]
 
 
57
 
58
+ cap_box = gr.Textbox(value=text, lines=3, interactive=True, show_label=False)
59
+ copy_btn = gr.Button("Copy Text")
 
60
 
61
  def copy_fn(caption):
62
  return caption
63
 
 
64
  copy_btn.click(fn=copy_fn, inputs=cap_box, outputs=cap_box)
65
 
66
  row = gr.Row([
67
  gr.Image(value=img, interactive=False, show_label=False, elem_id=f"history_img_{i}"),
68
  gr.Column([
 
69
  cap_box,
70
  copy_btn,
 
71
  ])
72
  ])
73
  rows.append(row)
74
  return rows
75
 
76
  with gr.Blocks() as iface:
77
+ gr.Markdown("# Image Captioning with Object Detection")
78
+
79
  gr.Markdown(
80
  """
81
+ Upload an image and click 'Generate Caption'.
82
+ The app will display detected objects and a caption together.
83
+ Your last 15 images and combined captions are shown below.
84
  """
85
  )
86
 
 
 
 
 
 
 
87
  with gr.Row():
88
  with gr.Column(scale=2):
89
  image_input = gr.Image(type="pil", label="Upload Image")
90
  generate_btn = gr.Button("Generate Caption")
91
  with gr.Column(scale=3):
92
+ output_box = gr.Textbox(label="Caption & Detected Objects", lines=6, interactive=True)
93
+ copy_btn = gr.Button("Copy Text")
 
94
 
95
  history_container = gr.Column()
96
 
97
+ def on_generate(image):
98
  if image is None:
99
+ return "Please upload an image.", []
100
+ combined_text = generate_caption_with_objects(image)
101
  history = build_history_ui()
102
+ return combined_text, history
103
 
104
  def copy_text(text):
105
  return gr.Textbox.update(value=text, interactive=True)
106
 
107
  generate_btn.click(
108
  fn=on_generate,
109
+ inputs=image_input,
110
+ outputs=[output_box, history_container],
111
  )
112
 
113
+ copy_btn.click(fn=copy_text, inputs=output_box, outputs=output_box)
114
 
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()