emilbm committed
Commit 5a5e912 · 0 parents

init project
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .DS_Store
+ __pycache__/
+ *.pyc
+ .venv
Dockerfile ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.12-slim
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user pyproject.toml ./
+ RUN pip install --no-cache-dir .
+
+ COPY --chown=user app ./app
+
+ # Start the app with Uvicorn
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
Makefile ADDED
@@ -0,0 +1,15 @@
+ APP_DIR := $(CURDIR)/app
+ TESTS_DIR := $(CURDIR)/tests
+
+ format:
+ 	uv run black $(APP_DIR) $(TESTS_DIR)/*.py
+ 	uv run ruff check $(APP_DIR) $(TESTS_DIR) --fix
+
+ lint:
+ 	uv run black --check $(APP_DIR) $(TESTS_DIR)/*.py
+ 	uv run ruff check $(APP_DIR) $(TESTS_DIR)
+ 	uv run mypy $(APP_DIR) $(TESTS_DIR)
+
+ test:
+ 	uv run pytest $(TESTS_DIR)
+
README.md ADDED
@@ -0,0 +1,145 @@
+ ---
+ title: Text2vector
+ emoji: 📊
+ colorFrom: purple
+ colorTo: green
+ sdk: docker
+ pinned: false
+ short_description: Create a vector embedding from text
+ ---
+
+ # Embedding API
+
+ API for calling an embedding model ([intfloat/multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large)) to generate multilingual text embeddings.<br>
+ The embedding model takes a text string and converts it into a 1024-dimensional vector.<br>
+ Send a `POST` request to the `/embed` endpoint with a list of texts, and the API returns their corresponding embeddings.<br>
+ A maximum of 2000 characters per text is enforced to avoid truncation by the tokenizer, and thereby loss of information.<br>
+ Each text must start with either "query: " or "passage: ".<br>
+
+
+ The API is deployed in a Hugging Face Docker Space, where the Swagger UI can be accessed at:<br>
+ [https://emilbm-text2vector.hf.space/docs](https://emilbm-text2vector.hf.space/docs)
+
+
+ ## Features
+
+ - FastAPI-based REST API
+ - `/embed` endpoint for generating embeddings from a list of texts
+ - `/health` endpoint for checking the API status
+ - Uses Hugging Face Transformers and PyTorch
+ - Includes linting and unit tests
+ - Dockerfile for containerization
+ - CI/CD with GitHub Actions to build, lint, test, and deploy to Hugging Face
+
+ ## Local Development
+
+ ### Requirements
+
+ - Python 3.12+
+ - [UV](https://docs.astral.sh/uv/)
+ - (Optional) Docker
+
+ ### Installation
+
+ 1. **Clone the repository:**
+ ```sh
+ git clone https://github.com/EmilbMadsen/embedding-api.git
+ cd embedding-api
+ ```
+
+ 2. **Create a virtual environment and activate it:**
+ ```sh
+ uv venv
+ source .venv/bin/activate
+ ```
+
+ 3. **Install dependencies:**
+ ```sh
+ uv sync
+ ```
+
+ ### Formatting, Linting and Unit Tests
+
+ - **Format (with Black and Ruff) and lint (with Black, Ruff, and MyPy):**
+ ```sh
+ make format
+ make lint
+ ```
+ - **Run unit tests:**
+ ```sh
+ make test
+ ```
+
+ ### Running Locally (without Docker)
+
+ Start the API server with Uvicorn:
+
+ ```sh
+ uvicorn app.main:app --reload --port 7860
+ ```
+
+ ### Running Locally (with Docker)
+
+ Build and start the API server with Docker:
+
+ ```sh
+ docker build -t embedding-api .
+ docker run -p 7860:7860 embedding-api
+ ```
+
+ ### Test the endpoint
+
+ Test the endpoint either with curl:
+ ```sh
+ curl -X 'POST' \
+ 'http://127.0.0.1:7860/embed' \
+ -H 'accept: application/json' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "texts": [
+ "query: what is the capital of France?",
+ "passage: Paris is the capital of France."
+ ]
+ }'
+ ```
+ Or through the Swagger UI.
+
+
+ ## Usage
+
+ ### Embed Endpoint
+
+ - **POST** `/embed`
+ - **Request Body:**
+ ```json
+ {
+ "texts": [
+ "query: what is the capital of France?",
+ "passage: Paris is the capital of France."
+ ]
+ }
+ ```
+ - **Response:**
+ ```json
+ {
+ "embeddings": [[...], [...]]
+ }
+ ```
+
+ ### Health Endpoint
+
+ - **GET** `/health`
+ - **Response:**
+ ```json
+ {
+ "status": "ok"
+ }
+ ```
+
+ ## Project Structure
+
+ ```
+ app/
+   main.py            # FastAPI app
+   embeddings.py      # Embedding logic
+   models.py          # Request/response models
+   logger.py          # Logging setup
+ tests/
+   test_api.py        # API tests
+   test_embeddings.py # Embedding tests
+ ```
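The `curl` call in the README can also be made from Python. A minimal client sketch using only the standard library; the helper name `build_embed_request` is mine, and the URL assumes a local run (swap in `https://emilbm-text2vector.hf.space/embed` for the deployed Space). The 2000-character and prefix checks mirror the API's documented input rules:

```python
import json
import urllib.request

API_URL = "http://127.0.0.1:7860/embed"  # deployed: https://emilbm-text2vector.hf.space/embed


def build_embed_request(texts: list[str]) -> urllib.request.Request:
    """Build a POST /embed request, applying the API's input rules client-side."""
    for t in texts:
        if not t.startswith(("query: ", "passage: ")):
            raise ValueError("Each text must start with 'query: ' or 'passage: '")
        if len(t) > 2000:
            raise ValueError("Texts are limited to 2000 characters")
    body = json.dumps({"texts": texts}).encode("utf-8")
    return urllib.request.Request(
        API_URL,
        data=body,
        headers={"accept": "application/json", "Content-Type": "application/json"},
        method="POST",
    )


# To send: urllib.request.urlopen(build_embed_request([...])) and read
# json.load(resp)["embeddings"]; each returned vector has 1024 floats.
req = build_embed_request(
    ["query: what is the capital of France?", "passage: Paris is the capital of France."]
)
```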
app/__init__.py ADDED
File without changes
app/embeddings.py ADDED
@@ -0,0 +1,38 @@
+ from transformers import AutoTokenizer, AutoModel
+ from torch import Tensor
+ from app.logger import logger
+
+ model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
+ tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
+
+
+ def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+     """Average pool the token embeddings, ignoring padded positions."""
+     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+ def embed_text(texts: list[str]) -> list[list[float]]:
+     """
+     Generate embeddings for a list of texts.
+
+     The model supports a maximum of 512 tokens per input, which typically corresponds to about 2000-2500 characters.
+     To avoid losing important information, we set a limit of 2000 characters per input text.
+     """
+     if not texts:
+         raise ValueError("No input texts provided.")
+     if any(len(text) > 2000 for text in texts):
+         raise ValueError(
+             "One or more input texts exceed the maximum length of 2000 characters."
+         )
+
+     batch_dict = tokenizer(
+         texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
+     )
+     logger.info(
+         f"Tokenized {len(texts)} texts with number of tokens per text: {batch_dict['input_ids'].ne(tokenizer.pad_token_id).sum(dim=1).tolist()}"
+     )
+     outputs = model(**batch_dict)
+     embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
+
+     return embeddings.detach().cpu().tolist()
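The masked mean that `average_pool` computes can be checked by hand. A stdlib-only sketch of the same computation on plain lists (no torch; the function name `average_pool_py` and the toy batch are mine, taken from the values used in `tests/test_embeddings.py`):

```python
def average_pool_py(
    hidden: list[list[list[float]]], mask: list[list[int]]
) -> list[list[float]]:
    """Masked mean over the token axis: positions where mask == 0 are excluded."""
    pooled = []
    for states, m in zip(hidden, mask):
        n = sum(m)  # number of real (non-padding) tokens in this sequence
        dims = len(states[0])
        pooled.append(
            [sum(tok[d] for tok, keep in zip(states, m) if keep) / n for d in range(dims)]
        )
    return pooled


# Same toy batch as the unit test: shape (2, 3, 2), masks [1, 1, 0] and [1, 0, 0]
hidden = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]],
]
mask = [[1, 1, 0], [1, 0, 0]]
print(average_pool_py(hidden, mask))  # [[2.0, 3.0], [10.0, 20.0]]
```

Row one averages the first two token vectors ((1+3)/2, (2+4)/2); row two keeps only its first token, matching what the torch version does with `masked_fill`.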
app/logger.py ADDED
@@ -0,0 +1,36 @@
+ import logging
+ from logging.config import dictConfig
+
+ LOGGING_CONFIG = {
+     "version": 1,
+     "disable_existing_loggers": False,
+     "formatters": {
+         "default": {
+             "format": "[%(asctime)s] [%(levelname)s] %(name)s: %(message)s",
+             "datefmt": "%Y-%m-%d %H:%M:%S",
+         },
+         "json": {
+             "format": (
+                 '{"time": "%(asctime)s", '
+                 '"level": "%(levelname)s", '
+                 '"name": "%(name)s", '
+                 '"message": "%(message)s"}'
+             ),
+             "datefmt": "%Y-%m-%d %H:%M:%S",
+         },
+     },
+     "handlers": {
+         "console": {
+             "class": "logging.StreamHandler",
+             "formatter": "default",
+         },
+     },
+     "root": {
+         "level": "INFO",
+         "handlers": ["console"],
+     },
+ }
+
+ dictConfig(LOGGING_CONFIG)
+
+ logger = logging.getLogger("app")
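The `json` formatter defined above is not wired to any handler, but it can be exercised directly. A small stdlib check of what that format string emits (logger name `demo` and the capture-to-StringIO setup are mine, for illustration); note that this %-style JSON template would produce invalid JSON if the message itself contained double quotes:

```python
import io
import json
import logging

# Reuse the same "json" format string as app/logger.py
fmt = (
    '{"time": "%(asctime)s", '
    '"level": "%(levelname)s", '
    '"name": "%(name)s", '
    '"message": "%(message)s"}'
)

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S"))

log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.INFO)

log.info("hello")  # caveat: a message containing '"' would break the JSON
record = json.loads(stream.getvalue())
print(record["level"], record["message"])  # INFO hello
```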
app/main.py ADDED
@@ -0,0 +1,56 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.middleware.cors import CORSMiddleware
+ from app.models import EmbedRequest, EmbedResponse
+ from app.embeddings import embed_text
+ from logging import getLogger
+ from pathlib import Path
+
+ logger = getLogger(__name__)
+
+ app = FastAPI(
+     title="Embedding API",
+     description="A simple API to generate text embeddings using the `intfloat/multilingual-e5-large` model.",
+     version="1.0.0",
+ )
+
+ # Mount the frontend directory as static files under /static
+ FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend"
+ if FRONTEND_DIR.exists():
+     app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
+
+ # Allow simple cross-origin requests for local development (restrict in production)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @app.post("/embed", response_model=EmbedResponse)
+ async def embed(request: EmbedRequest) -> dict[str, list[list[float]]]:
+     """Generate embeddings for a list of texts."""
+     try:
+         vectors = embed_text(request.texts)
+         return {"embeddings": vectors}
+     except Exception as e:
+         logger.exception("Error generating embeddings")
+         raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+ @app.get("/health")
+ async def health_check() -> dict[str, str]:
+     """Health check endpoint."""
+     return {"status": "ok"}
+
+
+ @app.get("/", response_model=None)
+ async def root() -> FileResponse | dict[str, str]:
+     """Serve the frontend `index.html` if present, otherwise return a small JSON status."""
+     index_file = FRONTEND_DIR / "index.html"
+     if index_file.exists():
+         return FileResponse(str(index_file))
+     return {"status": "ok", "message": "Frontend not found"}
app/models.py ADDED
@@ -0,0 +1,42 @@
+ from pydantic import BaseModel, Field, field_validator, StringConstraints
+ from typing import Annotated
+
+ PREFIX_ACCEPTED = ["query: ", "passage: "]
+
+ ShortText = Annotated[str, StringConstraints(max_length=2000)]
+
+
+ class EmbedRequest(BaseModel):
+     """
+     Request model for texts to be embedded.
+     Each text must be ≤ 2000 characters and start with either "query: " or "passage: ".
+     """
+
+     texts: list[ShortText] = Field(
+         ...,
+         json_schema_extra={
+             "example": [
+                 "query: what is the capital of France?",
+                 "passage: Paris is the capital of France.",
+             ]
+         },
+         description="List of texts to be embedded. Each must be ≤ 2000 characters and start with 'query: ' or 'passage: '.",
+     )
+
+     @field_validator("texts")
+     @classmethod
+     def check_prefixes(cls, texts: list[str]) -> list[str]:
+         for t in texts:
+             if not any(t.startswith(prefix) for prefix in PREFIX_ACCEPTED):
+                 raise ValueError(f"Each text must start with one of {PREFIX_ACCEPTED}")
+         return texts
+
+
+ class EmbedResponse(BaseModel):
+     """Response model containing embeddings."""
+
+     embeddings: list[list[float]] = Field(
+         ...,
+         description="List of embedding vectors corresponding to the input texts. Each embedding is a list of floats with length 1024.",
+     )
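The two input rules in `EmbedRequest` (the `StringConstraints` length cap plus the prefix validator) boil down to a couple of checks. A plain-Python sketch of the same validation, purely for illustration; the helper name `validate_texts` is mine:

```python
PREFIX_ACCEPTED = ("query: ", "passage: ")
MAX_CHARS = 2000


def validate_texts(texts: list[str]) -> list[str]:
    """Apply the same rules as EmbedRequest: <= 2000 chars, accepted prefix."""
    for t in texts:
        if len(t) > MAX_CHARS:
            raise ValueError(f"Text exceeds {MAX_CHARS} characters")
        if not t.startswith(PREFIX_ACCEPTED):  # str.startswith accepts a tuple
            raise ValueError(f"Each text must start with one of {list(PREFIX_ACCEPTED)}")
    return texts


validate_texts(["query: what is the capital of France?"])  # passes silently
```

In the real model, pydantic turns either failure into a 422 response before the endpoint body runs.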
frontend/index.html ADDED
@@ -0,0 +1,340 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Text2Vector | Text to Vector Conversion</title>
+ <link rel="icon" type="image/x-icon" href="/static/favicon.ico">
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script src="https://unpkg.com/feather-icons"></script>
+ <script src="https://cdn.jsdelivr.net/npm/feather-icons/dist/feather.min.js"></script>
+ <style>
+ .text-input-group:hover .remove-text-btn {
+ display: block !important;
+ }
+ .gradient-bg {
+ background: linear-gradient(135deg, #6e8efb 0%, #a777e3 100%);
+ }
+ .embed-card {
+ backdrop-filter: blur(10px);
+ background: rgba(255, 255, 255, 0.1);
+ border: 1px solid rgba(255, 255, 255, 0.2);
+ }
+ .text-area {
+ min-height: 150px;
+ }
+ .vector-display {
+ font-family: monospace;
+ white-space: pre-wrap;
+ overflow-x: auto;
+ }
+ #vanta-bg {
+ position: absolute;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ z-index: -1;
+ }
+ </style>
+ </head>
+ <body class="min-h-screen text-gray-100">
+ <div id="vanta-bg"></div>
+ <div class="container mx-auto px-4 py-12">
+ <!-- Header -->
+ <header class="text-center mb-12">
+ <h1 class="text-4xl md:text-5xl font-bold mb-4">Text2Vector ⚡</h1>
+ <p class="text-xl opacity-80">Transform your text into powerful vector embeddings</p>
+ </header>
+
+ <!-- Main Content -->
+ <main class="max-w-4xl mx-auto">
+ <div class="grid md:grid-cols-2 gap-8">
+ <!-- Input Section -->
+ <div class="embed-card rounded-xl p-6 shadow-lg">
+ <div class="flex items-center justify-between mb-4">
+ <div class="flex items-center">
+ <i data-feather="edit-3" class="mr-2"></i>
+ <h2 class="text-xl font-semibold">Input Texts</h2>
+ </div>
+ <button id="add-text-btn" class="px-3 py-1 gradient-bg hover:opacity-90 rounded-lg transition flex items-center text-sm">
+ <i data-feather="plus" class="mr-1"></i>
+ Add Field
+ </button>
+ </div>
+ <div id="text-inputs-container">
+ <div class="text-input-group mb-3 relative">
+ <textarea class="w-full text-area bg-gray-800 bg-opacity-50 rounded-lg p-4 text-white border border-gray-600 focus:border-purple-400 focus:ring-1 focus:ring-purple-400 transition" placeholder="Enter your text here (start with 'query: ' or 'passage: ')..."></textarea>
+ <button class="remove-text-btn absolute top-1 right-1 p-1 bg-gray-700 hover:bg-gray-600 rounded-full transition" style="display: none;">
+ <i data-feather="x" class="w-4 h-4"></i>
+ </button>
+ </div>
+ </div>
+ <div class="flex justify-between mt-4">
+ <button id="clear-btn" class="px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition flex items-center">
+ <i data-feather="trash-2" class="mr-2"></i>
+ Clear All
+ </button>
+ <button id="generate-btn" class="px-6 py-2 gradient-bg hover:opacity-90 rounded-lg transition flex items-center">
+ <i data-feather="zap" class="mr-2"></i>
+ Generate Embeddings
+ </button>
+ </div>
+ </div>
+ <!-- Output Section -->
+ <div class="embed-card rounded-xl p-6 shadow-lg">
+ <div class="flex items-center justify-between mb-4">
+ <div class="flex items-center">
+ <i data-feather="list" class="mr-2"></i>
+ <h2 class="text-xl font-semibold">Vector Embeddings</h2>
+ </div>
+ <button id="copy-btn" class="px-3 py-1 bg-gray-700 hover:bg-gray-600 rounded-lg transition flex items-center text-sm" disabled>
+ <i data-feather="copy" class="mr-1"></i>
+ Copy
+ </button>
+ </div>
+ <div id="output-container" class="vector-display bg-gray-800 bg-opacity-50 rounded-lg p-4 h-64 overflow-auto hidden">
+ <pre id="output-vector" class="text-sm"></pre>
+ </div>
+ <div id="placeholder" class="bg-gray-800 bg-opacity-30 rounded-lg p-8 text-center h-64 flex items-center justify-center">
+ <div class="opacity-60">
+ <i data-feather="wind" class="w-12 h-12 mx-auto mb-4"></i>
+ <p>Your embeddings will appear here</p>
+ </div>
+ </div>
+ <button id="download-btn" class="w-full mt-4 px-4 py-2 gradient-bg hover:opacity-90 rounded-lg transition flex items-center justify-center hidden">
+ <i data-feather="download" class="mr-2"></i>
+ Download as JSON
+ </button>
+ </div>
+ </div>
+
+ <!-- Info Section -->
+ <div class="embed-card rounded-xl p-6 mt-8 shadow-lg">
+ <div class="flex items-center mb-4">
+ <i data-feather="info" class="mr-2"></i>
+ <h2 class="text-xl font-semibold">About Text2Vector</h2>
+ </div>
+ <p class="mb-4">Text2Vector transforms your text into high-dimensional vector representations using the multilingual-e5-large model, capturing semantic meaning across multiple languages.</p>
+ <div class="grid md:grid-cols-3 gap-4">
+ <div class="bg-gray-800 bg-opacity-30 p-4 rounded-lg">
+ <div class="flex items-center mb-2">
+ <i data-feather="hash" class="mr-2 text-purple-300"></i>
+ <h3 class="font-medium">Dimensionality</h3>
+ </div>
+ <p class="text-sm opacity-80">1024-dimensional vectors</p>
+ </div>
+ <div class="bg-gray-800 bg-opacity-30 p-4 rounded-lg">
+ <div class="flex items-center mb-2">
+ <i data-feather="cpu" class="mr-2 text-purple-300"></i>
+ <h3 class="font-medium">Model</h3>
+ </div>
+ <p class="text-sm opacity-80">intfloat/multilingual-e5-large</p>
+ </div>
+ <div class="bg-gray-800 bg-opacity-30 p-4 rounded-lg">
+ <div class="flex items-center mb-2">
+ <i data-feather="code" class="mr-2 text-purple-300"></i>
+ <h3 class="font-medium">API</h3>
+ </div>
+ <p class="text-sm opacity-80">Simple REST integration</p>
+ </div>
+ </div>
+ </div>
+ </main>
+ </div>
+ <!-- Footer -->
+ <footer class="text-center py-8 opacity-70 text-sm">
+ <p>© 2023 Text2Vector ⚡ | Powered by multilingual-e5-large embeddings</p>
+ <p class="mt-2" id="api-status">API Status: <span class="text-red-500">Checking...</span></p>
+ </footer>
+ <!-- Scripts -->
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r121/three.min.js"></script>
+ <script src="https://cdn.jsdelivr.net/npm/vanta@latest/dist/vanta.globe.min.js"></script>
+ <script>
+ // Initialize Vanta.js background
+ VANTA.GLOBE({
+ el: "#vanta-bg",
+ mouseControls: true,
+ touchControls: true,
+ gyroControls: false,
+ minHeight: 200.00,
+ minWidth: 200.00,
+ scale: 1.00,
+ scaleMobile: 1.00,
+ color: 0x6e8efb,
+ backgroundColor: 0x0,
+ size: 0.8
+ });
+
+ // Initialize Feather Icons
+ feather.replace();
+ // DOM Elements
+ const textInputsContainer = document.getElementById('text-inputs-container');
+ const generateBtn = document.getElementById('generate-btn');
+ const clearBtn = document.getElementById('clear-btn');
+ const copyBtn = document.getElementById('copy-btn');
+ const downloadBtn = document.getElementById('download-btn');
+ const outputContainer = document.getElementById('output-container');
+ const outputVector = document.getElementById('output-vector');
+ const placeholder = document.getElementById('placeholder');
+ const addTextBtn = document.getElementById('add-text-btn');
+
+ // Add new text input field
+ addTextBtn.addEventListener('click', () => {
+ const newInputGroup = document.createElement('div');
+ newInputGroup.className = 'text-input-group mb-3 relative';
+ newInputGroup.innerHTML = `
+ <textarea class="w-full text-area bg-gray-800 bg-opacity-50 rounded-lg p-4 text-white border border-gray-600 focus:border-purple-400 focus:ring-1 focus:ring-purple-400 transition" placeholder="Enter your text here (start with 'query: ' or 'passage: ')..."></textarea>
+ <button class="remove-text-btn absolute top-1 right-1 p-1 bg-gray-700 hover:bg-gray-600 rounded-full transition">
+ <i data-feather="x" class="w-4 h-4"></i>
+ </button>
+ `;
+ textInputsContainer.appendChild(newInputGroup);
+ feather.replace();
+ setupRemoveButtons();
+ });
+
+ // Setup remove buttons for all input fields
+ function setupRemoveButtons() {
+ document.querySelectorAll('.remove-text-btn').forEach(btn => {
+ btn.addEventListener('click', (e) => {
+ if (document.querySelectorAll('.text-input-group').length > 1) {
+ e.target.closest('.text-input-group').remove();
+ }
+ });
+ });
+ }
+
+ // Show remove buttons when hovering over input groups
+ document.addEventListener('mouseover', (e) => {
+ if (e.target.closest('.text-input-group')) {
+ const group = e.target.closest('.text-input-group');
+ if (document.querySelectorAll('.text-input-group').length > 1) {
+ group.querySelector('.remove-text-btn').style.display = 'block';
+ }
+ }
+ });
+
+ document.addEventListener('mouseout', (e) => {
+ if (e.target.closest('.text-input-group')) {
+ const group = e.target.closest('.text-input-group');
+ group.querySelector('.remove-text-btn').style.display = 'none';
+ }
+ });
+
+ setupRemoveButtons();
+ // API Configuration
+ const API_URL = '/embed';
+ // Check API health
+ async function checkApiHealth() {
+ try {
+ const response = await fetch('/health');
+ if (response.ok) {
+ document.getElementById('api-status').innerHTML =
+ 'API Status: <span class="text-green-500">Online</span>';
+ } else {
+ throw new Error('API not responding');
+ }
+ } catch (error) {
+ document.getElementById('api-status').innerHTML =
+ 'API Status: <span class="text-red-500">Offline</span>';
+ console.error('API health check failed:', error);
+ }
+ }
+
+ // Initial health check
+ checkApiHealth();
+ setInterval(checkApiHealth, 30000); // Check every 30 seconds
+
+ // Event Listeners
+ generateBtn.addEventListener('click', async () => {
+ const textInputs = Array.from(document.querySelectorAll('.text-input-group textarea'));
+ const texts = textInputs
+ .map(input => input.value.trim())
+ .filter(text => text.length > 0);
+
+ if (texts.length === 0) {
+ alert('Please enter at least one text input');
+ return;
+ }
+ // Show loading state
+ generateBtn.disabled = true;
+ generateBtn.innerHTML = '<i data-feather="loader" class="animate-spin mr-2"></i> Processing...';
+ feather.replace();
+ // Make API call to backend
+ try {
+ const response = await fetch(API_URL, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ texts: texts
+ })
+ });
+ if (!response.ok) {
+ throw new Error(`API Error: ${response.status}`);
+ }
+ const data = await response.json();
+ const embeddings = texts.map((text, index) => ({
+ text: text,
+ vector: data.embeddings[index],
+ model: "multilingual-e5-large",
+ timestamp: new Date().toISOString()
+ }));
+
+ // Display the embeddings
+ outputVector.textContent = JSON.stringify(embeddings.length === 1 ? embeddings[0] : embeddings, null, 2);
+ placeholder.classList.add('hidden');
+ outputContainer.classList.remove('hidden');
+ copyBtn.disabled = false;
+ downloadBtn.classList.remove('hidden');
+
+ } catch (error) {
+ alert(`Failed to generate embeddings: ${error.message}`);
+ console.error(error);
+ }
+
+ // Reset button
+ generateBtn.disabled = false;
+ generateBtn.innerHTML = '<i data-feather="zap" class="mr-2"></i> Generate Embeddings';
+ feather.replace();
+ });
+
+ clearBtn.addEventListener('click', () => {
+ document.querySelectorAll('.text-input-group textarea').forEach(input => {
+ input.value = '';
+ });
+ outputVector.textContent = '';
+ outputContainer.classList.add('hidden');
+ placeholder.classList.remove('hidden');
+ copyBtn.disabled = true;
+ downloadBtn.classList.add('hidden');
+ });
+
+ copyBtn.addEventListener('click', () => {
+ navigator.clipboard.writeText(outputVector.textContent)
+ .then(() => {
+ copyBtn.innerHTML = '<i data-feather="check" class="mr-1"></i> Copied!';
+ feather.replace();
+ setTimeout(() => {
+ copyBtn.innerHTML = '<i data-feather="copy" class="mr-1"></i> Copy';
+ feather.replace();
+ }, 2000);
+ });
+ });
+
+ downloadBtn.addEventListener('click', () => {
+ const blob = new Blob([outputVector.textContent], { type: 'application/json' });
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.href = url;
+ a.download = `embedding-${new Date().getTime()}.json`;
+ document.body.appendChild(a);
+ a.click();
+ document.body.removeChild(a);
+ URL.revokeObjectURL(url);
+ });
+ </script>
+ </body>
+ </html>
pyproject.toml ADDED
@@ -0,0 +1,43 @@
+ [project]
+ name = "text2vector"
+ version = "1.0.0"
+ description = "API to call an embedding model"
+ readme = "README.md"
+ requires-python = ">=3.12"
+
+ dependencies = [
+     "black>=25.9.0",
+     "fastapi>=0.119.0",
+     "httpx>=0.28.1",
+     "mypy>=1.18.2",
+     "pydantic>=2.12.0",
+     "pytest>=8.4.2",
+     "ruff>=0.14.0",
+     "torch>=2.8.0",
+     "transformers>=4.57.0",
+     "uvicorn>=0.37.0",
+ ]
+
+ # https://quantlane.com/blog/type-checking-large-codebase/
+ [tool.mypy]
+ # Ensure full coverage
+ disallow_untyped_calls = false
+ disallow_untyped_defs = true
+ disallow_incomplete_defs = true
+ disallow_untyped_decorators = false
+ check_untyped_defs = true
+
+ # Restrict dynamic typing
+ disallow_any_generics = false
+ disallow_subclassing_any = false
+ warn_return_any = false
+
+ # Know exactly what you're doing
+ warn_redundant_casts = true
+ warn_unused_ignores = false
+ warn_unused_configs = true
+ warn_unreachable = true
+ show_error_codes = true
+
+ # Explicit is better than implicit
+ no_implicit_optional = true
tests/__init__.py ADDED
File without changes
tests/test_api.py ADDED
@@ -0,0 +1,26 @@
+ from fastapi.testclient import TestClient
+ from app.main import app
+
+ client = TestClient(app)
+
+
+ def test_embed() -> None:
+     """Test the /embed endpoint with valid input."""
+     response = client.post("/embed", json={"texts": ["query: Hello world"]})
+     assert response.status_code == 200  # OK
+     data = response.json()
+     assert "embeddings" in data
+     assert len(data["embeddings"][0]) == 1024
+
+
+ def test_embed_no_texts() -> None:
+     """Test the /embed endpoint with no texts provided."""
+     response = client.post("/embed", json={})
+     assert response.status_code == 422  # Unprocessable Entity
+
+
+ def test_embed_long_text() -> None:
+     """Test the /embed endpoint with a text longer than 2000 characters."""
+     long_text = "query: " + "a" * 1994  # 2001 characters in total
+     response = client.post("/embed", json={"texts": [long_text]})
+     assert response.status_code == 422  # Unprocessable Entity
tests/test_embeddings.py ADDED
@@ -0,0 +1,56 @@
+ from app.embeddings import average_pool, embed_text
+ import torch
+ import pytest
+
+
+ def test_average_pool_basic() -> None:
+     """Test average pooling produces correct shape and masking."""
+     last_hidden_states = torch.tensor(
+         [
+             [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+             [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]],
+         ]
+     )  # shape: (2, 3, 2)
+     attention_mask = torch.tensor(
+         [
+             [1, 1, 0],
+             [1, 0, 0],
+         ]
+     )  # shape: (2, 3)
+
+     result = average_pool(last_hidden_states, attention_mask)
+
+     # Expected averages:
+     # row 1: [(1+3)/2, (2+4)/2] = [2, 3]
+     # row 2: [10, 20]
+     expected = torch.tensor([[2.0, 3.0], [10.0, 20.0]])
+
+     assert torch.allclose(result, expected, atol=1e-6)
+     assert result.shape == (2, 2)
+
+
+ def test_embed_text_valid() -> None:
+     """Test embedding returns the correct number of vectors and dimensions."""
+     texts = ["query: Hello world", "query: Hej verden"]
+     embeddings = embed_text(texts)
+
+     # Assertions
+     assert isinstance(embeddings, list)
+     assert len(embeddings) == len(texts)
+     assert all(isinstance(vec, list) for vec in embeddings)
+     assert all(isinstance(x, float) for x in embeddings[0])
+     assert len(embeddings[0]) == 1024
+
+
+ def test_embed_text_empty_list() -> None:
+     """Should raise ValueError if no input texts are provided."""
+     with pytest.raises(ValueError, match="No input texts provided"):
+         embed_text([])
+
+
+ def test_embed_text_too_long() -> None:
+     """Should raise ValueError for inputs exceeding 2000 characters."""
+     too_long = ["query: " + "a" * 1994]  # 2001 characters in total
+     with pytest.raises(ValueError, match="exceed the maximum length"):
+         embed_text(too_long)
uv.lock ADDED
The diff for this file is too large to render. See raw diff