scmlewis committed on
Commit
3ee7d17
·
verified ·
1 Parent(s): 0f88d29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -25
app.py CHANGED
@@ -1,39 +1,110 @@
1
- from transformers import BlipProcessor, BlipForConditionalGeneration
 
2
  import gradio as gr
3
  from PIL import Image
 
4
 
5
- # Load the BLIP image captioning model and processor
6
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
7
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def preprocess_image(image):
10
  if image.mode != "RGB":
11
  image = image.convert("RGB")
12
  return image
13
 
14
- def generate_caption(image, max_length, num_beams):
15
- max_length = int(max_length) # cast to int to ensure correct type
16
- num_beams = int(num_beams) # cast to int to ensure correct type
 
 
17
 
 
18
  image = preprocess_image(image)
19
  inputs = processor(image, return_tensors="pt")
20
- out = model.generate(**inputs, max_length=max_length, num_beams=num_beams, early_stopping=True)
21
- caption = processor.decode(out[0], skip_special_tokens=True)
22
- return caption
23
-
24
- iface = gr.Interface(
25
- fn=generate_caption,
26
- inputs=[
27
- gr.Image(type="pil", label="Upload Image"),
28
- gr.Slider(10, 50, value=30, step=5, label="Caption Max Length",
29
- info="Controls length of caption (higher is longer)"),
30
- gr.Slider(1, 10, value=5, step=1, label="Beam Search Width",
31
- info="Higher means better caption quality (slower processing)")
32
- ],
33
- outputs=gr.Textbox(label="Generated Caption"),
34
- title="Simple Image Captioning App",
35
- description="Upload an image. The model generates a simple caption describing it."
36
- )
37
-
38
- if __name__ == "__main__":
39
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BlipProcessor, BlipForConditionalGeneration, MarianMTModel, MarianTokenizer
2
+ import torch
3
  import gradio as gr
4
  from PIL import Image
5
+ from collections import deque
6
 
7
+ # Load main BLIP model for English captioning
8
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
9
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
10
 
11
+ # Load YOLOv5 small model for object detection (using torch hub)
12
+ detect_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
13
+
14
+ # Setup MarianMT translation models cache for multilingual captions
15
+ translation_models = {
16
+ "English": None,
17
+ "French": ("Helsinki-NLP/opus-mt-en-fr", "Helsinki-NLP/opus-mt-fr-en"),
18
+ "Spanish": ("Helsinki-NLP/opus-mt-en-es", "Helsinki-NLP/opus-mt-es-en"),
19
+ "German": ("Helsinki-NLP/opus-mt-en-de", "Helsinki-NLP/opus-mt-de-en"),
20
+ }
21
+ translation_cache = {}
22
+
def get_translation_model(lang_code):
    """Return a cached ``(tokenizer, model)`` pair for translating English
    captions into *lang_code*.

    Models are loaded lazily on first use and memoized in
    ``translation_cache`` so subsequent calls are cheap.

    Args:
        lang_code: a key of ``translation_models`` other than "English".

    Returns:
        (MarianTokenizer, MarianMTModel) for the requested language.

    Raises:
        ValueError: if *lang_code* is unknown or has no translation model
            configured ("English" maps to ``None`` — callers must not
            request a translator for it).
    """
    if lang_code not in translation_cache:
        entry = translation_models.get(lang_code)
        if entry is None:
            # Previously this crashed with an opaque TypeError ("cannot
            # unpack None") for "English" or KeyError for unknown codes.
            raise ValueError(f"No translation model configured for {lang_code!r}")
        model_name, _ = entry
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        # Use a distinct name so the global BLIP `model` is not shadowed
        # confusingly inside this function.
        mt_model = MarianMTModel.from_pretrained(model_name)
        translation_cache[lang_code] = (tokenizer, mt_model)
    return translation_cache[lang_code]
30
+
31
def translate_caption(caption, target_lang):
    """Translate an English *caption* into *target_lang*.

    English input is returned unchanged; every other supported language
    is routed through its MarianMT model.
    """
    if target_lang == "English":
        return caption
    tokenizer, mt_model = get_translation_model(target_lang)
    encoded = tokenizer([caption], return_tensors="pt")
    output_ids = mt_model.generate(**encoded)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
+
40
+ # Session memory for last 15 images and captions
41
+ MEMORY_SIZE = 15
42
+ last_images = deque([], maxlen=MEMORY_SIZE)
43
+ last_captions = deque([], maxlen=MEMORY_SIZE)
44
+
45
  def preprocess_image(image):
46
  if image.mode != "RGB":
47
  image = image.convert("RGB")
48
  return image
49
 
50
def detect_objects(image):
    """Run YOLOv5 on *image* and return the unique detected class labels."""
    results = detect_model(image)
    names = results.names
    # The last column of results.xyxy[0] holds the predicted class index;
    # deduplicate via a set (label order is unspecified, as before).
    labels = {names[int(cls)] for cls in results.xyxy[0][:, -1]}
    return list(labels)
55
 
56
def generate_caption(image, language, max_length=30, num_beams=5):
    """Caption *image* in *language*, tag detected objects, and update
    the session memory gallery.

    Args:
        image: PIL image to caption.
        language: a key of ``translation_models`` ("English", "French", ...).
        max_length: maximum token length of the generated caption
            (restores the tunable the previous app version exposed;
            defaults to the old hard-coded 30).
        num_beams: beam-search width — higher is better quality, slower
            (defaults to the old hard-coded 5).

    Returns:
        (result_text, gallery): the formatted tags/caption string and the
        list of (image, caption) pairs currently held in session memory.
    """
    image = preprocess_image(image)
    inputs = processor(image, return_tensors="pt")
    out = model.generate(
        **inputs,
        max_length=int(max_length),  # cast so slider floats are accepted
        num_beams=int(num_beams),
        early_stopping=True,
    )
    caption_en = processor.decode(out[0], skip_special_tokens=True)
    caption_translated = translate_caption(caption_en, language)
    detected_objs = detect_objects(image)

    # Update session memory (the deques drop the oldest entry beyond
    # MEMORY_SIZE automatically).
    last_images.append(image)
    last_captions.append(caption_translated)

    # Format detected objects as a comma-separated tag list.
    tags = ", ".join(detected_objs) if detected_objs else "None"

    # Snapshot the memory as (thumbnail, caption) pairs for gr.Gallery.
    gallery = list(zip(last_images, last_captions))

    result_text = f"Detected objects: {tags}\nCaption ({language}): {caption_translated}"
    return result_text, gallery
76
+
77
+ # Gradio gallery components expect images as PIL Images or URLs, captions as texts
78
def gallery_to_components(gallery):
    """Split a list of (image, caption) pairs into parallel sequences.

    An empty gallery yields two empty lists; a non-empty one yields two
    tuples (the transpose of the pair list).
    """
    if not gallery:
        return [], []
    images, captions = zip(*gallery)
    return images, captions
81
+
82
with gr.Blocks() as iface:
    gr.Markdown("# Image Captioning with Object Detection & Multilingual Support")

    language = gr.Dropdown(
        label="Select Caption Language",
        choices=["English", "French", "Spanish", "German"],
        value="English"
    )

    image_input = gr.Image(type="pil", label="Upload Image")

    caption_output = gr.Textbox(label="Caption and Detected Objects", lines=3, interactive=False)

    # Gallery.style() was removed in Gradio 4.x — layout options are
    # constructor arguments now.
    gallery = gr.Gallery(
        label="Last 15 Images and Captions",
        columns=3,
        object_fit="contain",
        height="auto",
    )

    generate_btn = gr.Button("Generate Caption")

    def on_generate(image, language):
        """Guard against a missing upload, then delegate to generate_caption."""
        if image is None:
            return "Please upload an image.", []
        return generate_caption(image, language)

    generate_btn.click(
        fn=on_generate,
        inputs=[image_input, language],
        outputs=[caption_output, gallery]
    )

# Launch only when run as a script (keeps the guard the previous version
# had, so importing this module does not start a server).
if __name__ == "__main__":
    iface.launch()