scmlewis committed on
Commit
7a9f97a
·
verified ·
1 Parent(s): 3d6ca1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -56
app.py CHANGED
@@ -6,16 +6,17 @@ from PIL import Image
6
  from collections import deque
7
  import numpy as np
8
 
9
- # Load BLIP model for image captioning
10
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
11
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
12
 
13
- # Load YOLOv5 model for object detection
14
  detect_model = YOLO('yolov5s.pt')
15
 
 
16
  MEMORY_SIZE = 15
17
  last_images = deque([], maxlen=MEMORY_SIZE)
18
- last_texts = deque([], maxlen=MEMORY_SIZE) # will store combined caption + detected objects
19
 
20
  def preprocess_image(image):
21
  if image.mode != "RGB":
@@ -33,84 +34,44 @@ def detect_objects(image):
33
  detected_objs.add(label)
34
  return list(detected_objs)
35
 
36
- def generate_caption_with_objects(image):
37
  image = preprocess_image(image)
38
  inputs = processor(image, return_tensors="pt")
39
  out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
40
  caption = processor.decode(out[0], skip_special_tokens=True)
41
  detected_objs = detect_objects(image)
42
- tags = ", ".join(detected_objs) if detected_objs else "None"
43
-
44
- combined_text = f"Detected objects: {tags}\nCaption: {caption}"
45
 
46
  # Update session memory
47
  last_images.append(image)
48
- last_texts.append(combined_text)
49
-
50
- return combined_text
51
-
52
- def build_history_ui():
53
- rows = []
54
- for i in range(len(last_images)):
55
- img = last_images[i]
56
- text = last_texts[i]
57
-
58
- cap_box = gr.Textbox(value=text, lines=3, interactive=True, show_label=False)
59
- copy_btn = gr.Button("Copy Text")
60
-
61
- def copy_fn(caption):
62
- return caption
63
 
64
- copy_btn.click(fn=copy_fn, inputs=cap_box, outputs=cap_box)
 
65
 
66
- row = gr.Row([
67
- gr.Image(value=img, interactive=False, show_label=False, elem_id=f"history_img_{i}"),
68
- gr.Column([
69
- cap_box,
70
- copy_btn,
71
- ])
72
- ])
73
- rows.append(row)
74
- return rows
75
 
76
  with gr.Blocks() as iface:
77
  gr.Markdown("# Image Captioning with Object Detection")
78
 
79
- gr.Markdown(
80
- """
81
- Upload an image and click 'Generate Caption'.
82
- The app will display detected objects and a caption together.
83
- Your last 15 images and combined captions are shown below.
84
- """
85
- )
86
 
87
- with gr.Row():
88
- with gr.Column(scale=2):
89
- image_input = gr.Image(type="pil", label="Upload Image")
90
- generate_btn = gr.Button("Generate Caption")
91
- with gr.Column(scale=3):
92
- output_box = gr.Textbox(label="Caption & Detected Objects", lines=6, interactive=True)
93
- copy_btn = gr.Button("Copy Text")
94
 
95
- history_container = gr.Column()
96
 
97
  def on_generate(image):
98
  if image is None:
99
  return "Please upload an image.", []
100
- combined_text = generate_caption_with_objects(image)
101
- history = build_history_ui()
102
- return combined_text, history
103
-
104
- def copy_text(text):
105
- return gr.Textbox.update(value=text, interactive=True)
106
 
107
  generate_btn.click(
108
  fn=on_generate,
109
  inputs=image_input,
110
- outputs=[output_box, history_container],
111
  )
112
 
113
- copy_btn.click(fn=copy_text, inputs=output_box, outputs=output_box)
114
-
115
  if __name__ == "__main__":
116
  iface.launch()
 
6
  from collections import deque
7
  import numpy as np
8
 
9
# Load main BLIP processor and model for English image captioning.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load YOLOv5 small model for object detection using the ultralytics package.
# NOTE(review): 'yolov5s.pt' is resolved/downloaded by ultralytics at startup — confirm
# the deployment environment allows that download.
detect_model = YOLO('yolov5s.pt')

# Session memory for the last 15 images and captions; deque(maxlen=...) silently
# evicts the oldest entry once the limit is reached.
MEMORY_SIZE = 15
last_images = deque([], maxlen=MEMORY_SIZE)
last_captions = deque([], maxlen=MEMORY_SIZE)
20
 
21
  def preprocess_image(image):
22
  if image.mode != "RGB":
 
34
  detected_objs.add(label)
35
  return list(detected_objs)
36
 
37
def generate_caption(image):
    """Caption an image with BLIP, tag it with YOLO detections, update memory.

    Args:
        image: a PIL.Image (converted to RGB by preprocess_image).

    Returns:
        A ``(result_text, gallery)`` tuple where ``result_text`` combines the
        detected-object tags with the BLIP caption, and ``gallery`` is a list
        of ``(image, caption)`` pairs covering the last MEMORY_SIZE runs.
    """
    image = preprocess_image(image)

    # BLIP captioning: beam search with early stopping keeps captions short.
    inputs = processor(image, return_tensors="pt")
    out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
    caption = processor.decode(out[0], skip_special_tokens=True)

    # YOLO object tags for the same (preprocessed) image.
    detected_objs = detect_objects(image)

    # Update session memory; the bounded deques evict the oldest entry.
    last_images.append(image)
    last_captions.append(caption)

    tags = ", ".join(detected_objs) if detected_objs else "None"
    # zip() iterates the deques directly — no need to copy them into lists first.
    gallery = [(img, cap) for img, cap in zip(last_images, last_captions)]

    result_text = f"Detected objects: {tags}\nCaption: {caption}"
    return result_text, gallery
 
 
 
 
 
 
 
53
 
54
# Gradio UI: upload an image, generate a combined caption + object-tag text,
# and show the last MEMORY_SIZE (image, caption) pairs in a gallery.
with gr.Blocks() as iface:
    gr.Markdown("# Image Captioning with Object Detection")

    image_input = gr.Image(type="pil", label="Upload Image")
    caption_output = gr.Textbox(label="Caption and Detected Objects", lines=3, interactive=False)
    gallery = gr.Gallery(label="Last 15 Images and Captions", scale=3)
    generate_btn = gr.Button("Generate Caption")

    def on_generate(image):
        """Click handler: validate the upload, then delegate to generate_caption().

        Returns the (text, gallery) pair expected by the click() outputs.
        """
        if image is None:
            # Fix: previously returned [] here, which wiped the user's gallery
            # history on an empty submit. Return the current history instead.
            return "Please upload an image.", list(zip(last_images, last_captions))
        return generate_caption(image)

    generate_btn.click(
        fn=on_generate,
        inputs=image_input,
        outputs=[caption_output, gallery],
    )

if __name__ == "__main__":
    iface.launch()