|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import torch |
|
|
import json |
|
|
|
|
|
|
|
|
# Hugging Face model id of the sentence-embedding encoder served by this app.
model_name = "sentence-transformers/all-MiniLM-L6-v2"


# Loaded once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(model_name)


model = AutoModel.from_pretrained(model_name)
|
|
|
|
|
def get_embedding(text):
    """Return an L2-normalized sentence embedding for a single text.

    The text is tokenized (truncated to 512 tokens), run through the
    model without gradient tracking, and token embeddings are mean-pooled
    and unit-normalized.

    Args:
        text: Input string to embed.

    Returns:
        list[float]: The embedding vector as a plain Python list.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)

    with torch.no_grad():
        outputs = model(**inputs)

    # Mask-aware mean pooling: average only real-token positions instead of a
    # plain mean over every position. For a single unpadded input (the only
    # case today) this is identical to last_hidden_state.mean(dim=1), but it
    # stays correct if this is ever extended to padded batches, where padding
    # vectors would otherwise dilute the average.
    mask = inputs['attention_mask'].unsqueeze(-1).float()
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    embeddings = summed / counts

    # Unit-normalize so downstream cosine similarity reduces to a dot product.
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

    return embeddings.squeeze().tolist()
|
|
|
|
|
def predict_texts(texts):
    """Generate embeddings for a list of texts (for API compatibility)"""
    # Accept a bare string by promoting it to a one-element list.
    if isinstance(texts, str):
        texts = [texts]

    # Anything other than a list at this point is a caller error.
    if not isinstance(texts, list):
        return "Error: Input must be a list of texts or a single text string"

    results = []
    for item in texts:
        # Bail out on the first non-string item with a descriptive message.
        if not isinstance(item, str):
            return f"Error: All items must be strings, got {type(item)}"
        results.append(get_embedding(item))

    return results
|
|
|
|
|
def predict_single_text(text):
    """Generate embedding for a single text (for Gradio interface)"""
    # Treat None, empty, and whitespace-only input as "nothing entered".
    cleaned = text.strip() if text else ""
    if not cleaned:
        return "Please enter some text to generate embeddings."

    embedding = get_embedding(cleaned)
    return f"Embedding (first 10 values): {embedding[:10]}...\nFull embedding has {len(embedding)} dimensions."
|
|
|
|
|
def predict_api(texts):
    """Handle API calls from backend - expects list of texts directly"""
    try:
        # Reject anything that is not a list outright.
        if not isinstance(texts, list):
            return {'error': 'Input must be a list of texts'}

        vectors = []
        for entry in texts:
            # Any non-string item aborts the whole request.
            if not isinstance(entry, str):
                return {'error': 'All items must be strings'}
            vectors.append(get_embedding(entry))

        return {'data': vectors}
    except Exception as e:
        # Surface unexpected failures as a JSON error payload.
        return {'error': str(e)}
|
|
|
|
|
|
|
|
# JSON-in / JSON-out interface for programmatic callers; exposed as the
# "predict" API endpoint and mounted as the "API" tab in the main guard.
api_interface = gr.Interface(


    fn=predict_api,


    inputs=gr.JSON(),


    outputs=gr.JSON(),


    api_name="predict"


)
|
|
|
|
|
|
|
|
# Human-facing interface: free-text input, human-readable summary output;
# mounted as the "Web UI" tab in the main guard.
web_interface = gr.Interface(


    fn=predict_single_text,


    inputs=gr.Textbox(lines=3, placeholder="Enter text to generate embeddings..."),


    outputs=gr.Textbox(label="Embedding Result"),


    title="Text Embedding Generator",


    description="Generate embeddings for text using sentence-transformers/all-MiniLM-L6-v2 model",


    examples=[


        ["Hello world"],


        ["This is a test sentence for embedding generation."],


        ["Machine learning is transforming the world."]


    ]


)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Serve both tabs (human UI + JSON API) from a single app.
    # NOTE(review): a second, duplicate __main__ block used to follow this
    # one and called `iface.launch(...)` — `iface` was never defined anywhere
    # in the file, so it would raise NameError once the first launch()
    # returned. The dead duplicate has been removed.
    gr.TabbedInterface([web_interface, api_interface], ["Web UI", "API"]).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )