"""
SBERT Sentence Similarity Demo
==============================

Gradio demo that scores the semantic similarity of two sentences using one of
two models loaded from the HuggingFace Hub:

- Vanilla SBERT (train_10)
- SBERT + CWL (train_10, λ=0.1)

Requirements:
    pip install gradio transformers torch
"""
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
REPO = "SurAyush/sbert-sts-models"

print("Loading tokenizer...")
# A single tokenizer (from the vanilla checkpoint's subfolder) is shared by both models.
tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder="split_10")


def _load_encoder(subfolder):
    """Load the encoder stored under ``REPO/subfolder``, move it to DEVICE and
    switch it to eval mode before returning it."""
    encoder = AutoModel.from_pretrained(REPO, subfolder=subfolder).to(DEVICE)
    encoder.eval()
    return encoder


print("Loading Vanilla SBERT (train_10)...")
vanilla_model = _load_encoder("split_10")

print("Loading SBERT + CWL (train_10, λ=0.1)...")
cwl_model = _load_encoder("sbert_cwl_split_10_lambda_0_1")

# Display name -> loaded model. The keys are also the choices offered by the
# model selector in the UI.
MODELS = dict(
    [
        ("Vanilla SBERT (train_10)", vanilla_model),
        ("SBERT + CWL (train_10, λ=0.1)", cwl_model),
    ]
)
print("All models loaded.\n")
def mean_pool(token_embeddings, attention_mask):
    """Average token embeddings over the positions kept by ``attention_mask``.

    The mask is broadcast over the last axis of ``token_embeddings`` and the
    reduction runs over dim 1, so the expected shapes are presumably
    ``(batch, seq, hidden)`` for the embeddings and ``(batch, seq)`` for the
    mask. The result has shape ``(batch, hidden)``.

    Rows whose mask is entirely zero produce a zero vector: the token count
    is clamped to 1e-9 rather than being allowed to divide by zero.
    """
    weights = attention_mask.unsqueeze(-1).float()
    totals = torch.sum(token_embeddings * weights, dim=1)
    counts = torch.clamp(torch.sum(weights, dim=1), min=1e-9)
    return totals / counts
@torch.no_grad()
def get_similarity(model, sent1, sent2):
    """Return the cosine similarity of two sentences' mean-pooled embeddings.

    Both sentences are tokenized as one padded batch, truncated to 128
    tokens, and moved to DEVICE. They are encoded by ``model`` without
    gradient tracking and pooled with ``mean_pool``. The two pooled vectors
    are then compared with cosine similarity.

    Returns:
        The similarity as a Python float (cosine range, i.e. [-1, 1]).
    """
    batch = tokenizer(
        [sent1, sent2],
        max_length=128,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    hidden_states = model(**batch).last_hidden_state
    pooled = mean_pool(hidden_states, batch["attention_mask"])
    # Slice rather than index so each side keeps a leading batch axis of 1.
    first, second = pooled[:1], pooled[1:2]
    return F.cosine_similarity(first, second).item()
def _similarity_label(score):
    """Map a score in [0, 1] to the coloured label shown next to it.

    The 0.7 / 0.4 thresholds match the score table in the UI header.
    """
    if score >= 0.7:
        return "🟢 High Similarity"
    if score >= 0.4:
        return "🟡 Moderate Similarity"
    return "🔴 Low Similarity"


def predict(sentence1, sentence2, model_choice):
    """Score two sentences with the selected model and format the result.

    Args:
        sentence1: First sentence. Leading and trailing whitespace is ignored,
            and ``None`` is treated as empty.
        sentence2: Second sentence, handled the same way.
        model_choice: Display name of the model, i.e. a key of ``MODELS``.

    Returns:
        ``"<score> — <label>"``, with the score clipped to [0, 1] and shown to
        four decimals. Returns a prompt message instead when either sentence
        is blank or the model name is not recognised.
    """
    # Requests through Gradio's API endpoint may send None instead of "".
    sentence1 = (sentence1 or "").strip()
    sentence2 = (sentence2 or "").strip()
    if not sentence1 or not sentence2:
        return "Please enter both sentences."

    model = MODELS.get(model_choice)
    if model is None:
        return "Please select a valid model."

    score = get_similarity(model, sentence1, sentence2)
    # Cosine similarity lies in [-1, 1]. Clip it onto the 0-1 scale that the
    # UI advertises.
    score = max(0.0, min(1.0, score))
    return f"{score:.4f} — {_similarity_label(score)}"
with gr.Blocks(title="SBERT Sentence Similarity") as demo:
    # Blank lines separate the Markdown blocks; without them the project line
    # and the instructions render as a single run-on paragraph.
    gr.Markdown("""
# 🔍 Sentence Similarity Demo

**Project:** Improving Low-Resource Sentence Embedding Learning via Augmentation-Based Consistency Regularization

Enter two sentences and select a model to compute their semantic similarity score (0 to 1).

| Score | Meaning |
|---|---|
| 0.7 – 1.0 | High similarity |
| 0.4 – 0.7 | Moderate similarity |
| 0.0 – 0.4 | Low similarity |
""")

    with gr.Row():
        sent1 = gr.Textbox(
            label="Sentence 1",
            placeholder="Enter first sentence here...",
            lines=3,
        )
        sent2 = gr.Textbox(
            label="Sentence 2",
            placeholder="Enter second sentence here...",
            lines=3,
        )

    model_choice = gr.Radio(
        choices=list(MODELS.keys()),
        # Default to the first registered model instead of repeating its name.
        value=next(iter(MODELS)),
        label="Select Model",
    )

    submit_btn = gr.Button("Compute Similarity", variant="primary")

    output = gr.Textbox(
        label="Similarity Score",
        interactive=False,
    )

    # Clicking an example only fills in the two sentence boxes. No fn/outputs
    # are wired up because predict() also needs the model choice, which the
    # examples do not provide. With caching enabled, the previous wiring would
    # have called predict() with two arguments and raised a TypeError.
    # NOTE: the stray punctuation in some examples appears intentional
    # (noise robustness); keep it as is.
    gr.Examples(
        examples=[
            ["A man is playing a guitar.", "Someone is strumming a musical instrument."],
            ["She; loves to paint landscapes.", "She enjoys creating nature artwork"],
            ["The scientist discovered a new element.", "A researcher found a previously unknown substance."],
            ["He quickly ran to catch the bus.", "He rushed hurriedly to board the vehicle."],
            ["The, economy; is! recovering: slowly, from. the! recession;", "The economic situation is gradually improving after the downturn."],
        ],
        inputs=[sent1, sent2],
        cache_examples=False,
    )

    submit_btn.click(
        fn=predict,
        inputs=[sent1, sent2, model_choice],
        outputs=output,
    )

    gr.Markdown("""
---

**Models trained on STS-B benchmark (10% of training data)**

Backbone: `bert-base-uncased` | Pooling: Mean pooling | Metric: Cosine Similarity

HuggingFace: [SurAyush/sbert-sts-models](https://huggingface.co/SurAyush/sbert-sts-models)
""")


if __name__ == "__main__":
    demo.launch()