File size: 3,443 Bytes
b114468
7b031dc
 
b114468
b3f4657
 
7b031dc
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f4657
b114468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21f4259
 
b3f4657
 
21f4259
b3f4657
7b031dc
 
 
 
 
 
 
21f4259
b3f4657
21f4259
b3f4657
21f4259
b3f4657
b114468
 
 
21f4259
 
b114468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f4657
 
b114468
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import json

# Load the tokenizer and encoder once at import time so every request reuses
# the same weights instead of re-initializing per call. First run downloads
# the checkpoint from the Hugging Face hub (network required).
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def get_embedding(text):
    """Generate embedding for a single text"""
    # Tokenize with truncation so arbitrarily long input fits the model's
    # 512-token window.
    encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        output = model(**encoded)
        # Mean-pool the token embeddings into one sentence vector, then
        # L2-normalize it to unit length.
        pooled = output.last_hidden_state.mean(dim=1)
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return pooled.squeeze().tolist()

def predict_texts(texts):
    """Generate embeddings for a list of texts (for API compatibility)"""
    # A bare string is accepted as a one-element batch.
    if isinstance(texts, str):
        texts = [texts]

    if not isinstance(texts, list):
        return "Error: Input must be a list of texts or a single text string"

    # Embed each entry in order, bailing out on the first non-string item.
    results = []
    for item in texts:
        if not isinstance(item, str):
            return f"Error: All items must be strings, got {type(item)}"
        results.append(get_embedding(item))

    return results

def predict_single_text(text):
    """Generate embedding for a single text (for Gradio interface)"""
    # Reject empty or whitespace-only input before touching the model.
    if not (text and text.strip()):
        return "Please enter some text to generate embeddings."

    vector = get_embedding(text.strip())
    preview = vector[:10]
    return f"Embedding (first 10 values): {preview}...\nFull embedding has {len(vector)} dimensions."

def predict_api(texts):
    """Handle API calls from backend - expects list of texts directly"""
    try:
        if not isinstance(texts, list):
            return {'error': 'Input must be a list of texts'}

        # Embed each entry in order; abort on the first non-string item.
        vectors = []
        for item in texts:
            if not isinstance(item, str):
                return {'error': 'All items must be strings'}
            vectors.append(get_embedding(item))

        return {'data': vectors}
    except Exception as e:
        # Surface any unexpected failure as a JSON error payload rather
        # than a 500 with a traceback.
        return {'error': str(e)}

# Create API interface (this will create /api/predict endpoint)
# NOTE(review): `api_name` here is passed to the gr.Interface constructor —
# confirm the installed Gradio version accepts it there (it is more commonly
# a parameter of event listeners / .launch-exposed endpoints).
api_interface = gr.Interface(
    fn=predict_api,
    inputs=gr.JSON(),  # Expects JSON input directly
    outputs=gr.JSON(), # Returns JSON output directly
    api_name="predict"
)

# Create web interface: a simple textbox-in / textbox-out UI around
# predict_single_text, with a few clickable example inputs.
web_interface = gr.Interface(
    fn=predict_single_text,
    inputs=gr.Textbox(lines=3, placeholder="Enter text to generate embeddings..."),
    outputs=gr.Textbox(label="Embedding Result"),
    title="Text Embedding Generator",
    description="Generate embeddings for text using sentence-transformers/all-MiniLM-L6-v2 model",
    examples=[
        ["Hello world"],
        ["This is a test sentence for embedding generation."],
        ["Machine learning is transforming the world."]
    ]
)

# Launch both interfaces in one tabbed app.
if __name__ == '__main__':
    # BUG FIX: a second duplicate __main__ guard used to follow this one and
    # call `iface.launch(...)` on a name that is never defined anywhere in
    # the file — after this (blocking) launch returned it could only raise
    # NameError. The dead guard has been removed; this single launch serves
    # both the Web UI and the API tabs on port 7860.
    gr.TabbedInterface([web_interface, api_interface], ["Web UI", "API"]).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )