scmlewis committed on
Commit
a51fe1b
·
verified ·
1 Parent(s): 3ee7d17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -20
app.py CHANGED
@@ -1,15 +1,17 @@
1
  from transformers import BlipProcessor, BlipForConditionalGeneration, MarianMTModel, MarianTokenizer
 
2
  import torch
3
  import gradio as gr
4
  from PIL import Image
5
  from collections import deque
 
6
 
7
  # Load main BLIP model for English captioning
8
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
9
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
10
 
11
- # Load YOLOv5 small model for object detection (using torch hub)
12
- detect_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
13
 
14
  # Setup MarianMT translation models cache for multilingual captions
15
  translation_models = {
@@ -23,13 +25,16 @@ translation_cache = {}
23
  def get_translation_model(lang_code):
24
  if lang_code not in translation_cache:
25
  model_name, _ = translation_models[lang_code]
26
- tokenizer = MarianTokenizer.from_pretrained(model_name)
27
- model = MarianMTModel.from_pretrained(model_name)
28
- translation_cache[lang_code] = (tokenizer, model)
 
 
 
29
  return translation_cache[lang_code]
30
 
31
  def translate_caption(caption, target_lang):
32
- if target_lang == "English":
33
  return caption
34
  tokenizer, model = get_translation_model(target_lang)
35
  batch = tokenizer([caption], return_tensors="pt")
@@ -48,10 +53,15 @@ def preprocess_image(image):
48
  return image
49
 
50
  def detect_objects(image):
51
- results = detect_model(image)
52
- detected_labels = results.names
53
- objs = [detected_labels[int(x)] for x in results.xyxy[0][:, -1]]
54
- return list(set(objs)) # unique labels
 
 
 
 
 
55
 
56
  def generate_caption(image, language):
57
  image = preprocess_image(image)
@@ -65,20 +75,12 @@ def generate_caption(image, language):
65
  last_images.append(image)
66
  last_captions.append(caption_translated)
67
 
68
- # Format detected objects tags as comma-separated list
69
  tags = ", ".join(detected_objs) if detected_objs else "None"
70
-
71
- # Prepare last images gallery (thumbnails and captions)
72
  gallery = [(img, cap) for img, cap in zip(list(last_images), list(last_captions))]
73
-
74
  result_text = f"Detected objects: {tags}\nCaption ({language}): {caption_translated}"
75
  return result_text, gallery
76
 
77
- # Gradio gallery components expect images as PIL Images or URLs, captions as texts
78
- def gallery_to_components(gallery):
79
- images, captions = zip(*gallery) if gallery else ([], [])
80
- return images, captions
81
-
82
  with gr.Blocks() as iface:
83
  gr.Markdown("# Image Captioning with Object Detection & Multilingual Support")
84
 
@@ -107,4 +109,5 @@ with gr.Blocks() as iface:
107
  outputs=[caption_output, gallery]
108
  )
109
 
110
- iface.launch()
 
 
1
  from transformers import BlipProcessor, BlipForConditionalGeneration, MarianMTModel, MarianTokenizer
2
+ from ultralytics import YOLO
3
  import torch
4
  import gradio as gr
5
  from PIL import Image
6
  from collections import deque
7
+ import numpy as np
8
 
9
  # Load main BLIP model for English captioning
10
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
11
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
12
 
13
+ # Load YOLOv5 small model for object detection using ultralytics package
14
+ detect_model = YOLO('yolov5s.pt')
15
 
16
  # Setup MarianMT translation models cache for multilingual captions
17
  translation_models = {
 
25
def get_translation_model(lang_code):
    """Return the MarianMT (tokenizer, model) pair for *lang_code*, loading lazily.

    The pair is memoized in the module-level ``translation_cache``. Languages
    whose entry in ``translation_models`` has no model name are cached as
    ``None`` so callers can skip translation for them.
    """
    try:
        # Fast path: already loaded (or already known to be unavailable).
        return translation_cache[lang_code]
    except KeyError:
        pass
    model_name, _ = translation_models[lang_code]
    entry = None
    if model_name:
        entry = (
            MarianTokenizer.from_pretrained(model_name),
            MarianMTModel.from_pretrained(model_name),
        )
    translation_cache[lang_code] = entry
    return entry
35
 
36
  def translate_caption(caption, target_lang):
37
+ if target_lang == "English" or translation_cache.get(target_lang) is None:
38
  return caption
39
  tokenizer, model = get_translation_model(target_lang)
40
  batch = tokenizer([caption], return_tensors="pt")
 
53
  return image
54
 
55
def detect_objects(image):
    """Run YOLO object detection on *image* and return the unique class labels.

    The image is converted to a NumPy array for the ultralytics ``YOLO`` model;
    each detection's last column is its class id, mapped through
    ``detect_model.names`` to a human-readable label.
    """
    predictions = detect_model(np.array(image))
    labels = {
        detect_model.names[int(detection[-1])]
        for prediction in predictions
        for detection in prediction.boxes.data.tolist()
    }
    return list(labels)
65
 
66
  def generate_caption(image, language):
67
  image = preprocess_image(image)
 
75
  last_images.append(image)
76
  last_captions.append(caption_translated)
77
 
 
78
  tags = ", ".join(detected_objs) if detected_objs else "None"
 
 
79
  gallery = [(img, cap) for img, cap in zip(list(last_images), list(last_captions))]
80
+
81
  result_text = f"Detected objects: {tags}\nCaption ({language}): {caption_translated}"
82
  return result_text, gallery
83
 
 
 
 
 
 
84
  with gr.Blocks() as iface:
85
  gr.Markdown("# Image Captioning with Object Detection & Multilingual Support")
86
 
 
109
  outputs=[caption_output, gallery]
110
  )
111
 
112
+ if __name__ == "__main__":
113
+ iface.launch()