victorgg commited on
Commit
ece2cef
·
verified ·
1 Parent(s): 705bdac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -23
app.py CHANGED
@@ -12,9 +12,9 @@ florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base
12
  florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
13
 
14
  def generate_caption(image):
15
- # Ensure that the image is a PIL image
16
- if isinstance(image, np.ndarray):
17
- image = Image.fromarray(image) # Convert numpy array to PIL.Image if necessary
18
 
19
  # Prepare the input for the Florence model
20
  inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
@@ -30,16 +30,9 @@ def generate_caption(image):
30
  )
31
 
32
  # Decode the generated text
33
- generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
34
-
35
- # Post-process the generated text
36
- parsed_answer = florence_processor.post_process_generation(
37
- generated_text,
38
- task="<MORE_DETAILED_CAPTION>",
39
- image_size=(image.width, image.height)
40
- )
41
 
42
- return parsed_answer["<MORE_DETAILED_CAPTION>"]
43
 
44
  # Streamlit UI
45
  st.title("Florence 2 Caption Generator")
@@ -60,20 +53,21 @@ if uploaded_image is not None:
60
  st.write(caption)
61
 
62
  # ✅ API Mode: Handle API Requests
63
- st.experimental_set_query_params() # Ensure Streamlit can handle query params
64
-
65
  def handle_api_request():
66
  """Handle API request by checking URL query parameters."""
67
- query_params = st.experimental_get_query_params()
68
-
69
- if "image" in query_params:
70
- image_base64 = query_params["image"][0] # Get Base64-encoded image
71
- image_bytes = BytesIO(base64.b64decode(image_base64))
72
- image = Image.open(image_bytes)
73
 
74
- caption = generate_caption(image)
75
- st.json({"caption": caption}) # Return JSON response
 
 
 
 
 
 
 
 
76
 
77
  # Check if API mode is enabled
78
- if "image" in st.experimental_get_query_params():
79
  handle_api_request()
 
12
  florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
13
 
14
  def generate_caption(image):
15
+ """Generate a caption for the given image using Florence 2"""
16
+ # Convert image to RGB format to avoid channel errors
17
+ image = image.convert("RGB")
18
 
19
  # Prepare the input for the Florence model
20
  inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
 
30
  )
31
 
32
  # Decode the generated text
33
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
34
 
35
+ return generated_text
36
 
37
  # Streamlit UI
38
  st.title("Florence 2 Caption Generator")
 
53
  st.write(caption)
54
 
55
  # ✅ API Mode: Handle API Requests
 
 
56
  def handle_api_request():
57
  """Handle API request by checking URL query parameters."""
58
+ query_params = st.query_params
 
 
 
 
 
59
 
60
+ if "image" in query_params:
61
+ try:
62
+ image_base64 = query_params["image"]
63
+ image_bytes = BytesIO(base64.b64decode(image_base64))
64
+ image = Image.open(image_bytes).convert("RGB") # Ensure it's RGB
65
+
66
+ caption = generate_caption(image)
67
+ st.json({"caption": caption}) # Return JSON response
68
+ except Exception as e:
69
+ st.json({"error": str(e)})
70
 
71
  # Check if API mode is enabled
72
+ if "image" in st.query_params:
73
  handle_api_request()