import os

# Install fairseq2 (required by SONAR) at startup, pinned to the
# PyTorch 2.6 / CUDA 12.4 wheel index.
os.system("pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124 -q")

from io import BytesIO

import gradio as gr
import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from transformers import SiglipImageProcessor, SiglipVisionModel

# Cosine similarity along dim=1 (the embedding dimension).
cos = nn.CosineSimilarity(dim=1)
# Trained projection weights for the SONAR image encoder.
model_path = hf_hub_download(
    repo_id="Sibgat-Ul/SONAR-Image_enc",
    filename="best_sonar.pth",
    repo_type="model",
)

# Display names mapped to SONAR (FLORES-200-style) language codes.
language_mapping = {
    "English": "eng_Latn",
    "Bengali": "ben_Beng",
    "French": "fra_Latn",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------- Image Encoder --------
class SonarImageEnc(nn.Module):
    """Frozen SigLIP vision tower with a trainable projection into the
    1024-d SONAR text-embedding space."""

    def __init__(self, path="google/siglip2-base-patch16-384", initial_temperature=0.07):
        super().__init__()
        self.model = SiglipVisionModel.from_pretrained(path, torch_dtype="auto")

        # Freeze the vision backbone; only the projection head is trained.
        for param in self.model.parameters():
            param.requires_grad = False

        self.projection = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 2048),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024, eps=1e-5),
        )
        for param in self.projection.parameters():
            param.requires_grad = True

        # SigLIP-style learnable temperature and bias, plus a CLIP-style
        # logit scale (kept for checkpoint compatibility).
        self.temp_s = nn.Parameter(torch.log(torch.tensor(10.0)))
        self.bias = nn.Parameter(torch.tensor(-10.0))
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0) / initial_temperature))

    def forward(self, pixel_values):
        vision_outputs = self.model(pixel_values=pixel_values)
        pooled_output = vision_outputs.pooler_output
        embeddings = self.projection(pooled_output)

        # Keep the learned temperature within [0.001, 100]
        # (logit_scale = log(1 / T), so the bounds invert).
        self.logit_scale.data.clamp_(
            min=torch.log(torch.tensor(1.0, device=device) / torch.tensor(100.0, device=device)),
            max=torch.log(torch.tensor(1.0, device=device) / torch.tensor(0.001, device=device)),
        )
        return embeddings, torch.exp(self.logit_scale), torch.exp(self.temp_s), self.bias
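# A quick shape check (sketch, not executed by the app): the projection maps
# SigLIP's pooled output into the 1024-d SONAR sentence-embedding space, so
# image and caption embeddings become directly comparable. Assumes a 384x384
# RGB input, matching the patch16-384 checkpoint:
#
#   enc = SonarImageEnc()
#   with torch.no_grad():
#       emb, scale, temp, bias = enc(torch.randn(1, 3, 384, 384))
#   assert emb.shape == (1, 1024)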
# Load processor and models
processor = SiglipImageProcessor.from_pretrained("google/siglip2-base-patch16-384")

# SONAR text encoder: multilingual sentences -> 1024-d embeddings.
t2t_model_emb = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
    dtype=torch.float16,
)

img_encoder = SonarImageEnc().to(device).eval()
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
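# Optional sanity check (sketch, not run by the app): the text pipeline
# should return one 1024-d embedding per input sentence.
#
#   with torch.no_grad():
#       t = t2t_model_emb.predict(["Two cats."], source_lang="eng_Latn")
#   print(t.shape)  # expected: torch.Size([1, 1024])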
# -------- Similarity Scoring --------
def compute_similarity(
    image, image_url,
    option_a, option_b, option_c, option_d,
    lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d,
):
    # Fall back to the URL when no image was uploaded.
    if image is None:
        try:
            headers = {"User-Agent": "Mozilla/5.0"}
            response = requests.get(image_url, headers=headers)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return None, {"Error": f"Image could not be loaded: {str(e)}"}

    # Preprocess and embed the image.
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_emb, _, _, _ = img_encoder(inputs.pixel_values)
        image_emb = image_emb.to(device, torch.float16)

    # Map display names to SONAR language codes.
    lang_codes = [
        language_mapping[lang_opt_a],
        language_mapping[lang_opt_b],
        language_mapping[lang_opt_c],
        language_mapping[lang_opt_d],
    ]
    texts = [option_a, option_b, option_c, option_d]

    # Embed each option with its own source language.
    text_embeddings = []
    for text, lang in zip(texts, lang_codes):
        emb = t2t_model_emb.predict([text], source_lang=lang)
        text_embeddings.append(emb)
    text_embeddings = torch.cat(text_embeddings, dim=0).to(device)

    # Cosine similarity between the image and each caption.
    scores = cos(image_emb, text_embeddings)

    results = {
        f"Option {chr(65 + i)}": round(score.item(), 3)
        for i, score in enumerate(scores)
    }
    # Sort by score (descending) and render as percentages.
    results = {
        k: f"{round(v * 100, 2)}%"
        for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
    }
    return image, results
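# Example (sketch): calling the scorer directly, bypassing the UI. Uses the
# same COCO image as the demo's default URL; the caption strings are purely
# illustrative. Note that options sharing a language could instead be embedded
# in one batched t2t_model_emb.predict(texts, source_lang=...) call.
#
#   img, scores = compute_similarity(
#       None, "http://images.cocodataset.org/val2017/000000039769.jpg",
#       "Two cats.", "Two remotes.", "A dog.", "An airplane.",
#       "English", "English", "English", "English",
#   )
#   print(scores)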
# -------- Gradio UI --------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("## 🔍 SONAR: Image-Text Similarity Scorer")
    gr.Markdown("#### Upload an image or provide a URL.")

    with gr.Row():
        with gr.Column():
            image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")
            with gr.Row():
                option_a = gr.Textbox(label="Option A", value="Two cats in a bed.")
                lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
            with gr.Row():
                option_b = gr.Textbox(label="Option B", value="Two cats with two remotes.")
                lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
            with gr.Row():
                option_c = gr.Textbox(label="Option C", value="Two remotes.")
                lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
            with gr.Row():
                option_d = gr.Textbox(label="Option D", value="Two cats.")
                lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
        with gr.Column():
            image_input = gr.Image(label="Upload an image", type="pil")
            btn = gr.Button("Compute Similarity")

    with gr.Row():
        img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
        result_output = gr.JSON(label="Similarity Scores")

    btn.click(
        fn=compute_similarity,
        inputs=[
            image_input, image_url,
            option_a, option_b, option_c, option_d,
            lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d,
        ],
        outputs=[img_output, result_output],
    )

demo.launch()