# Web-page header residue captured with this file (not program text):
#   SurAyush's picture / example mod / 4832e82
"""
SBERT Sentence Similarity Demo
================================
Two models loaded from HuggingFace:
- Vanilla SBERT (train_10)
- SBERT + CWL (train_10, λ=0.1)
Requirements:
pip install gradio transformers torch
"""
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
# Run on the GPU when torch reports one as available; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hub repository that holds both checkpoints, each in its own subfolder.
REPO = "SurAyush/sbert-sts-models"

print("Loading tokenizer...")
# A single tokenizer (taken from the vanilla checkpoint's subfolder) is shared
# by every model: get_similarity() always tokenizes with this global.
tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder="split_10")

print("Loading Vanilla SBERT (train_10)...")
vanilla_model = AutoModel.from_pretrained(REPO, subfolder="split_10")
# Move to the chosen device and switch to evaluation mode for inference.
vanilla_model = vanilla_model.to(DEVICE).eval()

print("Loading SBERT + CWL (train_10, λ=0.1)...")
cwl_model = AutoModel.from_pretrained(
    REPO, subfolder="sbert_cwl_split_10_lambda_0_1"
)
cwl_model = cwl_model.to(DEVICE).eval()

# Display name -> loaded encoder. The keys double as the radio-button choices
# in the UI and as the lookup key used by predict().
MODELS = {
    "Vanilla SBERT (train_10)": vanilla_model,
    "SBERT + CWL (train_10, λ=0.1)": cwl_model,
}

print("All models loaded.\n")
def mean_pool(token_embeddings, attention_mask):
    """Average token embeddings over the positions kept by the attention mask.

    Args:
        token_embeddings: Tensor of per-token vectors. Dimension 1 is the one
            that gets averaged away (assumed layout: batch, tokens, hidden --
            this matches how get_similarity() calls it; confirm for other
            callers).
        attention_mask: Tensor with one entry per token (same leading
            dimensions as ``token_embeddings`` minus the last); non-zero
            entries mark tokens that contribute to the average.

    Returns:
        Tensor with dimension 1 removed, holding the mask-weighted mean of
        the token vectors. A row whose mask is all zeros yields zeros rather
        than NaN, because the divisor is clamped to a tiny positive value.
    """
    # Add a trailing axis so the mask broadcasts across the embedding width.
    weights = attention_mask[..., None].float()
    # Number of kept tokens per row; the clamp guards against division by zero.
    kept_count = weights.sum(dim=1).clamp(min=1e-9)
    weighted_total = (token_embeddings * weights).sum(dim=1)
    return weighted_total / kept_count
@torch.no_grad()
def get_similarity(model, sent1, sent2):
    """Return the cosine similarity between two sentences under ``model``.

    Both sentences are tokenized together with the module-level ``tokenizer``
    (padded to a common length, truncated at 128 tokens), encoded by
    ``model``, mean-pooled over tokens, and compared.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``.
        sent1: First sentence.
        sent2: Second sentence.

    Returns:
        float: Cosine similarity of the two pooled sentence embeddings
        (not clipped here, so it may be negative).
    """
    batch = tokenizer(
        [sent1, sent2],
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    batch = batch.to(DEVICE)

    hidden_states = model(**batch).last_hidden_state
    pooled = mean_pool(hidden_states, batch["attention_mask"])

    # Slice instead of index so each embedding keeps a leading batch axis of 1.
    first, second = pooled[0:1], pooled[1:2]
    return F.cosine_similarity(first, second).item()
def predict(sentence1, sentence2, model_choice="Vanilla SBERT (train_10)"):
    """Score two sentences and return a display string for the UI.

    Args:
        sentence1: First sentence as entered by the user. ``None`` is treated
            like an empty string.
        sentence2: Second sentence as entered by the user. ``None`` is treated
            like an empty string.
        model_choice: Key into ``MODELS`` selecting the encoder. Defaults to
            the vanilla model (the same default the radio button uses) so the
            function can also be called with just the two sentences, which is
            how ``gr.Examples`` lists its inputs.

    Returns:
        str: ``"<score> <label>"`` with the score clipped to [0, 1] and shown
        to four decimals, or a short message when an input is blank or the
        score is not a number.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``MODELS``.
    """
    import math  # local import: only needed for the NaN guard below

    # Normalize once; `or ""` keeps a missing (None) value from crashing strip().
    first = (sentence1 or "").strip()
    second = (sentence2 or "").strip()
    if not first or not second:
        return "Please enter both sentences."

    score = get_similarity(MODELS[model_choice], first, second)

    # max(0.0, min(1.0, nan)) evaluates to 1.0, which would present a failed
    # computation as a perfect match -- report it instead of clipping it.
    if math.isnan(score):
        return "Could not compute a similarity score for these sentences."

    # Cosine similarity can be negative; the UI advertises a 0-to-1 scale.
    score = max(0.0, min(1.0, score))

    # Thresholds mirror the table shown in the page header.
    if score >= 0.7:
        label = "🟢 High Similarity"
    elif score >= 0.4:
        label = "🟡 Moderate Similarity"
    else:
        label = "🔴 Low Similarity"

    # Separate the number from the label so the output reads "0.8123 🟢 ...".
    return f"{score:.4f} {label}"
# Build the Gradio page: header, two sentence boxes, model selector, button,
# result box, clickable examples, and a footer.
# NOTE(review): the nesting below was reconstructed -- only the two sentence
# boxes are assumed to live inside the Row; confirm against the intended layout.
with gr.Blocks(title="SBERT Sentence Similarity") as demo:
    # Header text plus the score-to-label table (same 0.7 / 0.4 cut-offs that
    # predict() applies).
    gr.Markdown("""
# 🔍 Sentence Similarity Demo
**Project:** Improving Low-Resource Sentence Embedding Learning via Augmentation-Based Consistency Regularization
Enter two sentences and select a model to compute their semantic similarity score (0 to 1).
| Score | Meaning |
|---|---|
| 0.7 – 1.0 | High similarity |
| 0.4 – 0.7 | Moderate similarity |
| 0.0 – 0.4 | Low similarity |
""")
    # The two free-text inputs that are compared.
    with gr.Row():
        sent1 = gr.Textbox(
            label="Sentence 1",
            placeholder="Enter first sentence here...",
            lines=3
        )
        sent2 = gr.Textbox(
            label="Sentence 2",
            placeholder="Enter second sentence here...",
            lines=3
        )
    # One radio option per entry in MODELS; the vanilla model is preselected.
    model_choice = gr.Radio(
        choices=list(MODELS.keys()),
        value="Vanilla SBERT (train_10)",
        label="Select Model"
    )
    submit_btn = gr.Button("Compute Similarity", variant="primary")
    # Read-only box that shows the string returned by predict().
    output = gr.Textbox(
        label="Similarity Score",
        interactive=False
    )
    # examples
    # Each row fills Sentence 1 and Sentence 2. Rows 2 and 5 carry stray
    # punctuation in the first sentence -- presumably deliberate noise to
    # compare the models' robustness; confirm before "cleaning" them.
    # NOTE(review): predict() declares a third parameter (model_choice) that is
    # not among these inputs. Per the Gradio docs `fn`/`outputs` are only used
    # when examples are cached or run on click, so verify the signature before
    # enabling either.
    gr.Examples(
        examples=[
            ["A man is playing a guitar.", "Someone is strumming a musical instrument."],
            ["She; loves to paint landscapes.", "She enjoys creating nature artwork"],
            ["The scientist discovered a new element.", "A researcher found a previously unknown substance."],
            ["He quickly ran to catch the bus.", "He rushed hurriedly to board the vehicle."],
            ["The, economy; is! recovering: slowly, from. the! recession;", "The economic situation is gradually improving after the downturn."]
        ],
        inputs=[sent1, sent2],
        outputs=output,
        fn=predict,
        cache_examples=False
    )
    # Clicking the button sends both sentences and the selected model name to
    # predict() and writes its return value into the result box.
    submit_btn.click(
        fn=predict,
        inputs=[sent1, sent2, model_choice],
        outputs=output
    )
    # Footer: training-data note, model details, and a link to the Hub repo.
    gr.Markdown("""
---
**Models trained on STS-B benchmark (10% of training data)**
Backbone: `bert-base-uncased` | Pooling: Mean pooling | Metric: Cosine Similarity
HuggingFace: [SurAyush/sbert-sts-models](https://huggingface.co/SurAyush/sbert-sts-models)
""")

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()