jatinteamoxio committed on
Commit
4bebfad
·
verified ·
1 Parent(s): 08799a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -13
app.py CHANGED
@@ -1,28 +1,59 @@
1
  import gradio as gr
2
  from transformers import BlipProcessor, BlipForConditionalGeneration
3
  from PIL import Image
 
 
4
 
5
  # Load processor and model
6
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
7
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
8
 
9
# Captioning function
def caption_image(image):
    """Return a BLIP-generated caption for a PIL image."""
    model_inputs = processor(images=image, return_tensors="pt")
    generated_ids = model.generate(**model_inputs)
    return processor.decode(generated_ids[0], skip_special_tokens=True)
15
 
16
# Interface: single image in, caption text out.
image_input = gr.Image(type="pil")
iface = gr.Interface(
    fn=caption_image,
    inputs=image_input,
    outputs="text",
    title="Explain this Image",
    flagging_mode="never",
    # Name of the endpoint exposed to API callers.
    api_name="predict",
)
 
 
 
 
 
 
 
 
26
 
27
# Enable the request queue (API open) before launching publicly.
queued_app = iface.queue(api_open=True)
queued_app.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import BlipProcessor, BlipForConditionalGeneration
3
  from PIL import Image
4
+ import base64
5
+ import io
6
 
7
  # Load processor and model
8
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
9
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
10
 
11
# Captioning function for direct image input
def caption_image(image):
    """Return a BLIP-generated caption for a PIL image."""
    model_inputs = processor(images=image, return_tensors="pt")
    generated_ids = model.generate(**model_inputs)
    return processor.decode(generated_ids[0], skip_special_tokens=True)
17
 
18
# API endpoint function that can handle base64 images
def api_caption_image(base64_img):
    """Caption an image supplied as a base64 string.

    Accepts either a raw base64 payload or a full data URL
    ("data:image/...;base64,...").

    Returns:
        The generated caption, or a string starting with
        "Error processing image:" when decoding/captioning fails.
    """
    try:
        # Strip the data-URL prefix if present; split once so the
        # payload itself is never fragmented.
        if "," in base64_img:
            base64_img = base64_img.split(",", 1)[1]

        # Decode base64 to a PIL image.
        image_bytes = base64.b64decode(base64_img)
        image = Image.open(io.BytesIO(image_bytes))

        # BLIP expects 3-channel input; normalize palette/RGBA/CMYK
        # uploads instead of letting the processor choke on them.
        image = image.convert("RGB")

        # Delegate to the shared captioning path rather than
        # duplicating the processor/model calls here.
        return caption_image(image)
    except Exception as e:
        # Best-effort API: surface the failure as text, never crash.
        return f"Error processing image: {str(e)}"
36
 
37
# Create Blocks for more flexibility
with gr.Blocks() as demo:
    # Human-facing tab: upload an image, read the caption.
    with gr.Tab("Image Captioning"):
        gr.Interface(
            fn=caption_image,
            inputs=gr.Image(type="pil"),
            outputs="text",
            title="Explain this Image",
            flagging_mode="never",
        )

    # Define the API endpoint explicitly: base64 text in, caption out.
    gr.Interface(
        fn=api_caption_image,
        inputs=gr.Textbox(),  # For base64 input
        outputs="text",
        title="API Endpoint",
        flagging_mode="never",
        api_name="predict",  # This is the API endpoint name
    )
57
+
58
+ # Launch with queue and API open
59
+ demo.queue(api_open=True).launch(share=True)