# Upload metadata carried over from the hosting page (not code):
#   uploader: samiulhaq -- "Upload 5 files"
#   commit: d06da6c (verified)
import os
import base64
from similarity_score import JinaV4SimilarityMapper
def save_base64_image(base64_str, output_path):
    """Decode a base64 string and write the resulting bytes to ``output_path``.

    The payload is decoded *before* the destination is opened, so a
    malformed string raises without creating or truncating ``output_path``.

    Args:
        base64_str: Base64-encoded data (``str`` or bytes-like), as accepted
            by ``base64.b64decode``.
        output_path: Destination file path; opened in binary write mode.

    Raises:
        binascii.Error: If ``base64_str`` cannot be decoded (e.g. bad padding).
        OSError: If ``output_path`` cannot be opened for writing.
    """
    # Decode first: opening with "wb" truncates the target immediately, which
    # would leave an empty file behind if decoding then failed.
    image_bytes = base64.b64decode(base64_str)
    with open(output_path, "wb") as f:
        f.write(image_bytes)
    print(f"Saved: {output_path}")
def visual_grounding_heatmaps(text_query, image_source, task):
    """Generate and save per-token similarity heatmaps for each (task, query) pair.

    For every task ``t`` and every query at index ``i``, this calls
    ``JinaV4SimilarityMapper(task=t).get_token_similarity_maps`` and writes
    the results into the directory ``heatmap_results_{i}_{t}``:

    * ``g_{t}_{i}_score.txt`` containing ``str(g_score)``;
    * one ``heatmap_<token>.png`` per returned token that has an entry in
      the returned heatmaps, decoded from base64 via ``save_base64_image``.

    Failures are handled per query on a best-effort basis: the error is
    printed and processing continues with the next query.

    Args:
        text_query: A single query string, or an iterable of query strings.
            A bare string is treated as one query (not as a sequence of
            characters).
        image_source: Passed unchanged to the mapper as ``image=``.
        task: A single task name, or an iterable of task names. A bare
            string is treated as one task.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    # A bare string would otherwise be iterated character by character.
    queries = [text_query] if isinstance(text_query, str) else list(text_query)
    tasks = [task] if isinstance(task, str) else list(task)

    print(f"Processing query: '{text_query}'")
    print("Generating heatmaps...")

    for t in tasks:
        # Built on first use and then reused for every query of this task,
        # instead of being re-created per query.
        # NOTE(review): assumes the mapper keeps no per-query state -- confirm.
        mapper = None
        for i, query in enumerate(queries):
            try:
                if mapper is None:
                    mapper = JinaV4SimilarityMapper(task=t)

                tokens, heatmaps, g_score = mapper.get_token_similarity_maps(
                    query=query,
                    image=image_source,
                )

                output_dir = f"heatmap_results_{i}_{t}"
                os.makedirs(output_dir, exist_ok=True)

                # Persist the score returned alongside the heatmaps.
                score_path = os.path.join(output_dir, f"g_{t}_{i}_score.txt")
                with open(score_path, "w", encoding="utf-8") as f:
                    f.write(str(g_score))

                print(f"\nFound {len(tokens)} valid tokens_score.", g_score)

                for token in tokens:
                    if token not in heatmaps:
                        continue
                    # Replace every non-alphanumeric character so the token is
                    # usable as a file name.
                    # NOTE: tokens differing only in such characters map to the
                    # same file name, so the later one overwrites the earlier.
                    safe_token_name = "".join(
                        c if c.isalnum() else "_" for c in token
                    )
                    output_path = os.path.join(
                        output_dir, f"heatmap_{safe_token_name}.png"
                    )
                    save_base64_image(heatmaps[token], output_path)

                print("\nAll heatmaps saved successfully!")
            except Exception as e:
                # Deliberate best-effort: report and move on to the next query.
                print(f"An error occurred: {e}")
def main():
    """Run the multimodal-consistency demo on a sample image and print the result."""
    # Author's notes carried over from the original (not verifiable from this file):
    #   - client_type="web" uses the API client when the model is not available
    #     locally; client_type="local" runs the model weights on GPU.
    #   - In "web" mode an API key may need to be set in JinaEmbeddingsClient
    #     (its default is "Bearer Not Set").
    #   - Known tasks: retrieval, text-matching.

    # Inputs: the image may be a local file path or a URL.
    image_path = "cyclists.jpg"
    # Presumably (source text, target text), judging by the "Src"/"Tgt" keys in
    # the example result below -- verify against calculate_multimodal_consistency.
    source_text = "A group of cyclists riding nearby the ocean"
    target_text = "A group of cyclists riding nearby the ocean"

    scorer = JinaV4SimilarityMapper(task='retrieval')
    results = scorer.calculate_multimodal_consistency(
        source_text, target_text, image_path
    )

    # Example result values:
    #   {
    #       "MMSS": round(mmss, 4),
    #       "Final_Compound_Score": round(final_score, 4),
    #       "Text_Fidelity (Src-Tgt)": round(score_src_tgt, 4),
    #       "Visual_Grounding (Tgt-Img)": round(score_tgt_img, 4),
    #       "Image_Relevance (Src-Img)": round(score_src_img, 4),
    #       "Fusion_Weight": round(lambda_weight, 4)
    #   }
    print("Multimodal Consistency Results:")
    print(results)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()