Spaces:
Sleeping
Sleeping
import functools

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
# Text normalization applied before tokenization.
def preprocess_text(text):
    """Return *text* with surrounding whitespace removed, then lowercased."""
    trimmed = text.strip()
    normalized = trimmed.lower()
    return normalized
# Checkpoint used when the caller does not request a specific model.
DEFAULT_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'


# Bounded because every cache entry holds a full set of model weights.
@functools.lru_cache(maxsize=2)
def _load_model(model_name):
    """Load the (tokenizer, model) pair for *model_name* once and memoize it.

    Previously ``from_pretrained`` ran on every embedding request, re-reading
    the weights each time. Caching keeps the loaded objects for reuse.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model


# Function to get embeddings for the texts
def get_embeddings(texts, model_name=DEFAULT_MODEL_NAME):
    """Return mean-pooled, L2-normalized embeddings for *texts*.

    Args:
        texts: Iterable of strings; each is passed through ``preprocess_text``
            before tokenization.
        model_name: Hugging Face checkpoint to embed with. Defaults to
            ``DEFAULT_MODEL_NAME``, the model the original code hard-coded.

    Returns:
        A tensor with one row per input text, each row scaled to unit L2 norm
        along dim 1.
    """
    tokenizer, model = _load_model(model_name)

    cleaned = [preprocess_text(t) for t in texts]
    encoded_input = tokenizer(cleaned, padding=True, truncation=True, return_tensors='pt')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model_output = model(**encoded_input)

    token_embeddings = model_output.last_hidden_state
    # Expand the attention mask over the hidden dimension so that padding
    # positions contribute nothing to the sum.
    mask = encoded_input['attention_mask'].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(1)
    # Per-row count of real tokens. The clamp guards the division against a
    # row with no attended tokens.
    counts = mask.sum(1).clamp(min=1e-9)
    embeddings = summed / counts

    return torch.nn.functional.normalize(embeddings, p=2, dim=1)
# Similarity scoring used by the Gradio callback.
def calculate_similarity(text1, text2):
    """Embed both texts and report their cosine similarity as display text.

    The result is formatted to four decimal places, e.g.
    ``"Cosine Similarity: 0.8123"``.
    """
    first, second = get_embeddings([text1, text2])
    score_tensor = torch.nn.functional.cosine_similarity(first[None, :], second[None, :])
    score = score_tensor.item()
    return f"Cosine Similarity: {score:.4f}"
# Two free-text inputs feeding calculate_similarity, one text output.
_text_inputs = [
    gr.Textbox(label="Text 1", placeholder="Enter the first text here..."),
    gr.Textbox(label="Text 2", placeholder="Enter the second text here..."),
]

interface = gr.Interface(
    title="Cosine Similarity Calculator",
    description="Provide two texts to calculate the cosine similarity between them using embeddings from a transformer model.",
    fn=calculate_similarity,
    inputs=_text_inputs,
    outputs="text",
)

# Start the web UI only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()