jatinteamoxio committed on
Commit
de09b8d
·
verified ·
1 Parent(s): 4bebfad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -40
app.py CHANGED
@@ -8,25 +8,26 @@ import io
8
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
9
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
10
 
11
- # Captioning function for direct image input
12
- def caption_image(image):
13
- inputs = processor(images=image, return_tensors="pt")
14
- out = model.generate(**inputs)
15
- caption = processor.decode(out[0], skip_special_tokens=True)
16
- return caption
17
-
18
- # API endpoint function that can handle base64 images
19
- def api_caption_image(base64_img):
20
  try:
21
- # Remove the data URL prefix if present
22
- if "," in base64_img:
23
- base64_img = base64_img.split(",")[1]
24
-
25
- # Decode base64 to image
26
- image_bytes = base64.b64decode(base64_img)
27
- image = Image.open(io.BytesIO(image_bytes))
28
-
29
- # Process with model
 
 
 
 
 
 
 
 
30
  inputs = processor(images=image, return_tensors="pt")
31
  out = model.generate(**inputs)
32
  caption = processor.decode(out[0], skip_special_tokens=True)
@@ -34,26 +35,18 @@ def api_caption_image(base64_img):
34
  except Exception as e:
35
  return f"Error processing image: {str(e)}"
36
 
37
- # Create Blocks for more flexibility
38
- with gr.Blocks() as demo:
39
- with gr.Tab("Image Captioning"):
40
- gr.Interface(
41
- fn=caption_image,
42
- inputs=gr.Image(type="pil"),
43
- outputs="text",
44
- title="Explain this Image",
45
- flagging_mode="never",
46
- )
47
-
48
- # Define the API endpoint explicitly
49
- gr.Interface(
50
- fn=api_caption_image,
51
- inputs=gr.Textbox(), # For base64 input
52
- outputs="text",
53
- title="API Endpoint",
54
- flagging_mode="never",
55
- api_name="predict" # This is the API endpoint name
56
- )
57
 
58
- # Launch with queue and API open
59
- demo.queue(api_open=True).launch(share=True)
 
8
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
9
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
10
 
11
+ # Function to process both standard images and base64 strings
12
+ def process_image(input_data):
 
 
 
 
 
 
 
13
  try:
14
+ # Check if input is a base64 string
15
+ if isinstance(input_data, str) and input_data.startswith("data:image"):
16
+ # Extract the base64 part
17
+ base64_data = input_data.split(",")[1]
18
+ image_bytes = base64.b64decode(base64_data)
19
+ image = Image.open(io.BytesIO(image_bytes))
20
+ elif isinstance(input_data, str) and len(input_data) > 100: # Likely a base64 string without prefix
21
+ try:
22
+ image_bytes = base64.b64decode(input_data)
23
+ image = Image.open(io.BytesIO(image_bytes))
24
+ except:
25
+ return "Error: Invalid base64 image format"
26
+ else:
27
+ # Standard image input
28
+ image = input_data
29
+
30
+ # Generate caption
31
  inputs = processor(images=image, return_tensors="pt")
32
  out = model.generate(**inputs)
33
  caption = processor.decode(out[0], skip_special_tokens=True)
 
35
  except Exception as e:
36
  return f"Error processing image: {str(e)}"
37
 
38
+ # Create the demo with both direct image upload and API endpoint
39
+ demo = gr.Interface(
40
+ fn=process_image,
41
+ inputs=[
42
+ gr.Image(type="pil", label="Upload Image")
43
+ ],
44
+ outputs=gr.Textbox(label="Image Caption"),
45
+ title="Image Captioning",
46
+ description="Upload an image to get a caption",
47
+ examples=[],
48
+ allow_flagging="never"
49
+ )
 
 
 
 
 
 
 
 
50
 
51
+ # Important: Expose the same function for API usage
52
+ demo.launch(share=True)