# Captured from a Hugging Face "Spaces" page; Space status at capture time: Sleeping.
"""
SBERT Sentence Similarity Demo
================================

Two models loaded from HuggingFace:
  - Vanilla SBERT (train_10)
  - SBERT + CWL (train_10, λ=0.1)

Requirements:
    pip install gradio transformers torch
"""
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
# Use the CUDA device when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hub repository that stores both checkpoints, one per subfolder.
REPO = "SurAyush/sbert-sts-models"

# A single tokenizer (taken from the vanilla checkpoint's subfolder) is shared
# by both models below.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder="split_10")

# Each encoder is moved to DEVICE and switched to eval mode right after loading.
print("Loading Vanilla SBERT (train_10)...")
vanilla_model = AutoModel.from_pretrained(REPO, subfolder="split_10").to(DEVICE).eval()

print("Loading SBERT + CWL (train_10, λ=0.1)...")
cwl_model = AutoModel.from_pretrained(
    REPO,
    subfolder="sbert_cwl_split_10_lambda_0_1",
).to(DEVICE).eval()

# Display name -> loaded encoder. The keys double as the UI radio choices,
# so their text and order are user-visible.
MODELS = {}
MODELS["Vanilla SBERT (train_10)"] = vanilla_model
MODELS["SBERT + CWL (train_10, λ=0.1)"] = cwl_model

print("All models loaded.\n")
def mean_pool(token_embeddings, attention_mask):
    """Average token embeddings along dim 1, counting only unmasked positions.

    Args:
        token_embeddings: Per-token embedding tensor; the caller in this file
            passes a model's ``last_hidden_state`` (presumably shaped
            ``(batch, seq_len, hidden)`` -- confirm against the model).
        attention_mask: Mask tensor with one entry per token; positions with
            value 0 contribute nothing to the average.

    Returns:
        The mask-weighted mean of ``token_embeddings`` over dim 1.
    """
    # Add a trailing axis so the mask broadcasts across the embedding dimension.
    weights = attention_mask[..., None].float()
    # Clamp the token count so a fully masked row divides by 1e-9, not zero.
    token_counts = torch.clamp(torch.sum(weights, dim=1), min=1e-9)
    weighted_total = torch.sum(token_embeddings * weights, dim=1)
    return weighted_total / token_counts
def get_similarity(model, sent1, sent2):
    """Return the cosine similarity between the embeddings of two sentences.

    Both sentences are tokenized together as one padded batch (truncated to
    128 tokens), encoded by ``model``, mean-pooled over their tokens, and
    compared with cosine similarity.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``.
        sent1: First sentence.
        sent2: Second sentence.

    Returns:
        The cosine similarity as a Python float. The value is not clipped
        here, so it can be negative.
    """
    enc = tokenizer(
        [sent1, sent2],
        max_length=128,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    # Inference only: without no_grad() every request would build an autograd
    # graph and keep the activations alive for no benefit.
    with torch.no_grad():
        out = model(**enc)
        embs = mean_pool(out.last_hidden_state, enc["attention_mask"])
        # Keep a leading batch axis on each row so cosine_similarity compares
        # the two vectors along the embedding dimension.
        u, v = embs[0].unsqueeze(0), embs[1].unsqueeze(0)
        return F.cosine_similarity(u, v).item()
def predict(sentence1, sentence2, model_choice=None):
    """Score two sentences with the selected model and format the result.

    Args:
        sentence1: First sentence; ``None`` is treated like an empty string.
        sentence2: Second sentence; ``None`` is treated like an empty string.
        model_choice: Key into ``MODELS``. When omitted or ``None``, the first
            entry of ``MODELS`` is used, so the function can also be called
            with only the two sentences (as the examples widget does).

    Returns:
        ``"<score> — <label>"`` with the score clipped to [0, 1] and printed
        with four decimals, or a prompt asking for both sentences when either
        one is empty or whitespace-only.

    Raises:
        KeyError: If ``model_choice`` is given but is not a key of ``MODELS``.
    """
    # Guard against None (e.g. a programmatic caller) before stripping.
    first = (sentence1 or "").strip()
    second = (sentence2 or "").strip()
    if not first or not second:
        return "Please enter both sentences."

    if model_choice is None:
        model_choice = next(iter(MODELS))
    model = MODELS[model_choice]

    score = get_similarity(model, first, second)
    score = max(0.0, min(1.0, score))  # clip to [0, 1]

    # Thresholds mirror the score table shown in the UI.
    if score >= 0.7:
        label = "🟢 High Similarity"
    elif score >= 0.4:
        label = "🟡 Moderate Similarity"
    else:
        label = "🔴 Low Similarity"
    return f"{score:.4f} — {label}"
# Gradio UI: two sentence boxes, a model selector, a button and a read-only
# result box, plus clickable example sentence pairs.
# NOTE(review): indentation and blank lines were lost in the captured source.
# The nesting below (only the two sentence boxes inside the Row) and the exact
# whitespace inside the Markdown literals are reconstructed -- confirm against
# the deployed app.
with gr.Blocks(title="SBERT Sentence Similarity") as demo:
    # Model preselected in the radio and used when an example is scored.
    default_model = "Vanilla SBERT (train_10)"

    gr.Markdown("""
# 🔍 Sentence Similarity Demo
**Project:** Improving Low-Resource Sentence Embedding Learning via Augmentation-Based Consistency Regularization
Enter two sentences and select a model to compute their semantic similarity score (0 to 1).
| Score | Meaning |
|---|---|
| 0.7 – 1.0 | High similarity |
| 0.4 – 0.7 | Moderate similarity |
| 0.0 – 0.4 | Low similarity |
""")

    with gr.Row():
        sent1 = gr.Textbox(
            label="Sentence 1",
            placeholder="Enter first sentence here...",
            lines=3,
        )
        sent2 = gr.Textbox(
            label="Sentence 2",
            placeholder="Enter second sentence here...",
            lines=3,
        )

    model_choice = gr.Radio(
        choices=list(MODELS.keys()),
        value=default_model,
        label="Select Model",
    )
    submit_btn = gr.Button("Compute Similarity", variant="primary")
    output = gr.Textbox(
        label="Similarity Score",
        interactive=False,
    )

    # examples
    # NOTE(review): the stray punctuation in some examples looks like
    # deliberate input noise; it is reproduced verbatim.
    gr.Examples(
        examples=[
            ["A man is playing a guitar.", "Someone is strumming a musical instrument."],
            ["She; loves to paint landscapes.", "She enjoys creating nature artwork"],
            ["The scientist discovered a new element.", "A researcher found a previously unknown substance."],
            ["He quickly ran to catch the bus.", "He rushed hurriedly to board the vehicle."],
            ["The, economy; is! recovering: slowly, from. the! recession;", "The economic situation is gradually improving after the downturn."],
        ],
        inputs=[sent1, sent2],
        outputs=output,
        # predict() takes three arguments but the examples only supply the two
        # sentences, so bind the default model here. Passing predict directly
        # would fail with a TypeError as soon as this callback is actually
        # invoked (e.g. if example caching is turned on).
        fn=lambda first, second: predict(first, second, default_model),
        cache_examples=False,
    )

    submit_btn.click(
        fn=predict,
        inputs=[sent1, sent2, model_choice],
        outputs=output,
    )

    gr.Markdown("""
---
**Models trained on STS-B benchmark (10% of training data)**
Backbone: `bert-base-uncased` | Pooling: Mean pooling | Metric: Cosine Similarity
HuggingFace: [SurAyush/sbert-sts-models](https://huggingface.co/SurAyush/sbert-sts-models)
""")

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()