scmlewis committed on
Commit
f0f418e
·
verified ·
1 Parent(s): d5530d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -1,24 +1,50 @@
1
- from transformers import BlipProcessor, BlipForConditionalGeneration
2
- import torch
3
  import gradio as gr
4
  from PIL import Image
5
 
6
- # Load model and processor
7
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
8
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
9
 
10
- def generate_caption(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  inputs = processor(image, return_tensors="pt")
12
- out = model.generate(**inputs)
13
- caption = processor.decode(out[0], skip_special_tokens=True)
14
- return caption
 
 
15
 
16
- # Create Gradio interface
17
  iface = gr.Interface(
18
  fn=generate_caption,
19
- inputs=gr.Image(type="pil"),
 
 
 
 
 
 
20
  outputs="text",
21
- title="Image Captioning App"
 
22
  )
23
 
24
  if __name__ == "__main__":
 
1
+ from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
 
2
  import gradio as gr
3
  from PIL import Image
4
 
5
+ # Load the main image captioning model
6
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
7
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
8
 
9
+ # Load a small language model pipeline for polishing (can change model to a better LLM as needed)
10
+ text_generator = pipeline("text-generation", model="gpt2")
11
+
12
def preprocess_image(image):
    """Normalize an input image to RGB mode.

    Parameters
    ----------
    image : PIL.Image.Image
        The uploaded image (any mode, e.g. "L", "RGBA", "P").

    Returns
    -------
    PIL.Image.Image
        The image unchanged when already "RGB", otherwise an
        RGB-converted copy.
    """
    # Non-RGB inputs (grayscale, palette, alpha) are converted so the
    # downstream processor always receives three channels.
    return image if image.mode == "RGB" else image.convert("RGB")
17
+
18
def postprocess_caption(raw_caption):
    """Polish a raw BLIP caption using the text-generation pipeline.

    Parameters
    ----------
    raw_caption : str
        Caption text produced by the captioning model.

    Returns
    -------
    str
        The generated text with the instruction prefix stripped and
        surrounding whitespace removed (the raw caption itself remains
        part of the output).
    """
    # Keep the instruction in one place so generation and stripping
    # cannot drift out of sync.
    instruction = "Describe this image in more detail: "
    prompt = instruction + raw_caption
    # Use max_new_tokens rather than max_length: max_length counts the
    # prompt tokens too, so a long caption could leave no budget for the
    # continuation. max_new_tokens bounds only the generated text.
    polished = text_generator(prompt, max_new_tokens=50, num_return_sequences=1)
    generated = polished[0]['generated_text']
    # removeprefix strips only the leading instruction; str.replace would
    # also delete the phrase if the model happened to echo it mid-output.
    polished_caption = generated.removeprefix(instruction).strip()
    return polished_caption
25
+
26
def generate_caption(image, max_length=30, num_beams=5):
    """Caption an image and return a polished description.

    The image is normalized to RGB, captioned by the BLIP model using
    beam search, and the decoded caption is then refined by the
    text-generation pipeline.

    Parameters
    ----------
    image : PIL.Image.Image
        The input image.
    max_length : int, optional
        Upper bound on the generated caption length (default 30).
    num_beams : int, optional
        Beam-search width; more beams trade speed for quality (default 5).

    Returns
    -------
    str
        The polished caption text.
    """
    rgb_image = preprocess_image(image)
    encoded = processor(rgb_image, return_tensors="pt")
    # Beam search with early stopping for higher-quality captions.
    generated_ids = model.generate(
        **encoded,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    decoded = processor.decode(generated_ids[0], skip_special_tokens=True)
    return postprocess_caption(decoded)
34
 
35
# Gradio UI: an image upload plus two sliders exposing the generation knobs.
image_input = gr.Image(type="pil", label="Upload Image")
length_slider = gr.Slider(
    10, 50, value=30, step=5, label="Caption Max Length",
    info="Controls the length of the caption (higher means longer captions)",
)
beam_slider = gr.Slider(
    1, 10, value=5, step=1, label="Beam Search Width",
    info="Controls how many caption options the model considers before picking the best",
)

iface = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, length_slider, beam_slider],
    outputs="text",
    title="Enhanced Image Captioning App",
    description="Upload an image and get a polished, detailed description. Adjust sliders to control caption length and quality.",
)
49
 
50
  if __name__ == "__main__":