scmlewis committed on
Commit
4671d3e
·
verified ·
1 Parent(s): 54766e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -41
app.py CHANGED
@@ -6,15 +6,17 @@ from PIL import Image
6
  from collections import deque
7
  import numpy as np
8
 
9
- # Load BLIP model
10
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
11
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
12
 
13
- # Load YOLOv5 model
14
  detect_model = YOLO('yolov5s.pt')
15
 
 
16
  MEMORY_SIZE = 15
17
- last_texts = deque([], maxlen=MEMORY_SIZE)
 
18
 
19
  def preprocess_image(image):
20
  if image.mode != "RGB":
@@ -35,62 +37,41 @@ def detect_objects(image):
35
  def generate_caption(image):
36
  image = preprocess_image(image)
37
  inputs = processor(image, return_tensors="pt")
38
-
39
  out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
40
  caption = processor.decode(out[0], skip_special_tokens=True)
41
  detected_objs = detect_objects(image)
42
- tags = ", ".join(detected_objs) if detected_objs else "None"
43
-
44
- combined_text = f"Detected objects: {tags}\nCaption: {caption}"
45
 
46
- last_texts.append(combined_text)
 
 
47
 
48
- return combined_text
 
49
 
50
- def build_history_table():
51
- headers = ["Past Outputs (Click text to copy)"]
52
- data = [[text] for text in reversed(last_texts)]
53
- return headers, data
54
 
55
  with gr.Blocks() as iface:
56
  gr.Markdown("# Image Captioning with Object Detection")
57
 
58
- with gr.Column(scale=2):
59
- image_input = gr.Image(type="pil", label="Upload Image")
60
- generate_btn = gr.Button("Generate Caption")
61
- caption_output = gr.Textbox(
62
- label="Caption and Detected Objects",
63
- lines=6,
64
- interactive=True
65
- )
66
- copy_btn = gr.Button("Copy Output")
67
-
68
- history_table = gr.Dataframe(
69
- headers=["Session History"],
70
- datatype=["str"],
71
- interactive=True,
72
- row_count=(0, MEMORY_SIZE),
73
- col_count=1,
74
- wrap=True
75
- )
76
 
77
  def on_generate(image):
78
  if image is None:
79
- return "Please upload an image.", (["Session History"], [])
80
- combined = generate_caption(image)
81
- headers, data = build_history_table()
82
- return combined, (headers, data)
83
-
84
- def copy_output(text):
85
- return gr.Textbox.update(value=text, interactive=True)
86
 
87
  generate_btn.click(
88
  fn=on_generate,
89
  inputs=image_input,
90
- outputs=[caption_output, history_table]
91
  )
92
 
93
- copy_btn.click(fn=copy_output, inputs=caption_output, outputs=caption_output)
94
-
95
  if __name__ == "__main__":
96
  iface.launch()
 
6
  from collections import deque
7
  import numpy as np
8
 
9
# BLIP processor/model pair for English image captioning.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# YOLOv5-small weights for object detection (ultralytics package).
detect_model = YOLO('yolov5s.pt')

# Bounded session memory: the deques silently evict the oldest entry
# once MEMORY_SIZE image/caption pairs have been stored.
MEMORY_SIZE = 15
last_images = deque(maxlen=MEMORY_SIZE)
last_captions = deque(maxlen=MEMORY_SIZE)
20
 
21
  def preprocess_image(image):
22
  if image.mode != "RGB":
 
37
def generate_caption(image):
    """Caption an image with BLIP and tag it with YOLO detections.

    The image/caption pair is appended to the session memory deques
    (oldest entries drop out past MEMORY_SIZE).

    Returns:
        tuple: (result_text, gallery) where result_text is a two-line
        string with the detected-object tags and the caption, and
        gallery is a list of (image, caption) pairs for display.
    """
    image = preprocess_image(image)
    inputs = processor(image, return_tensors="pt")

    # Beam search keeps captions short and fluent.
    out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
    caption = processor.decode(out[0], skip_special_tokens=True)
    detected_objs = detect_objects(image)

    # Record this result in the session memory.
    last_images.append(image)
    last_captions.append(caption)

    tags = ", ".join(detected_objs) if detected_objs else "None"
    gallery = list(zip(last_images, last_captions))

    result_text = f"Detected objects: {tags}\nCaption: {caption}"
    return result_text, gallery
 
 
53
 
54
# Gradio UI: upload an image, generate a caption plus object tags, and
# browse the last MEMORY_SIZE results in a gallery.
with gr.Blocks() as iface:
    gr.Markdown("# Image Captioning with Object Detection")

    image_input = gr.Image(type="pil", label="Upload Image")

    caption_output = gr.Textbox(label="Caption and Detected Objects", lines=3, interactive=False)

    gallery = gr.Gallery(label="Last 15 Images and Captions", scale=3)

    generate_btn = gr.Button("Generate Caption")

    def on_generate(image):
        """Guard against a missing upload, then delegate to generate_caption."""
        if image is None:
            return "Please upload an image.", []
        return generate_caption(image)

    generate_btn.click(fn=on_generate, inputs=image_input, outputs=[caption_output, gallery])

if __name__ == "__main__":
    iface.launch()