Commit · 5a5e912 · 0 Parent(s)
init project
- .gitignore +4 -0
- Dockerfile +15 -0
- Makefile +15 -0
- README.md +145 -0
- app/__init__.py +0 -0
- app/embeddings.py +38 -0
- app/logger.py +36 -0
- app/main.py +56 -0
- app/models.py +42 -0
- frontend/index.html +340 -0
- pyproject.toml +43 -0
- tests/__init__.py +0 -0
- tests/test_api.py +26 -0
- tests/test_embeddings.py +56 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+.DS_Store
+__pycache__/
+*.pyc
+.venv
Dockerfile
ADDED
@@ -0,0 +1,15 @@
+FROM python:3.12-slim
+
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+
+WORKDIR /app
+
+COPY --chown=user pyproject.toml ./
+RUN pip install --no-cache-dir .
+
+COPY --chown=user app ./app
+
+# Start the app with Uvicorn
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
Makefile
ADDED
@@ -0,0 +1,15 @@
+APP_DIR := $(CURDIR)/app
+TESTS_DIR := $(CURDIR)/tests
+
+format:
+	uv run black $(APP_DIR) $(TESTS_DIR)/*.py
+	uv run ruff check $(APP_DIR) $(TESTS_DIR) --fix
+
+lint:
+	uv run black --check $(APP_DIR) $(TESTS_DIR)/*.py
+	uv run ruff check $(APP_DIR) $(TESTS_DIR)
+	uv run mypy $(APP_DIR) $(TESTS_DIR)
+
+test:
+	uv run pytest $(TESTS_DIR)
+
README.md
ADDED
@@ -0,0 +1,145 @@
+---
+title: Text2vector
+emoji: 📊
+colorFrom: purple
+colorTo: green
+sdk: docker
+pinned: false
+short_description: Create a vector embedding from text
+---
+
+# Embedding API
+
+API to call an embedding model ([intfloat/multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large)) for generating multilingual text embeddings.<br>
+The embedding model takes a text string and converts it into a 1024-dimensional vector.<br>
+Using a `POST` request to the `/embed` endpoint with a list of texts, the API returns their corresponding embeddings.<br>
+A maximum of 2000 characters per text is enforced so that the tokenizer does not truncate the input and lose information.<br>
+Each text must start with either "query: " or "passage: ".<br>
+
+
+The API is deployed in a Hugging Face Docker Space, where the Swagger UI can be accessed at:<br>
+[https://emilbm-text2vector.hf.space/docs](https://emilbm-text2vector.hf.space/docs)
+
+
+## Features
+
+- FastAPI-based REST API
+- `/embed` endpoint for generating embeddings from a list of texts
+- `/health` endpoint for checking the API status
+- Uses Hugging Face Transformers and PyTorch
+- Includes linting and unit tests
+- Dockerfile for containerization
+- CI/CD with GitHub Actions to build, lint, test, and deploy to Hugging Face
+
+## Local Development
+### Requirements
+
+- Python 3.12+
+- [uv](https://docs.astral.sh/uv/)
+- (Optional) Docker
+### Installation
+
+1. **Clone the repository:**
+   ```sh
+   git clone https://github.com/EmilbMadsen/embedding-api.git
+   cd embedding-api
+   ```
+
+2. **Create a virtual environment and activate it:**
+   ```sh
+   uv venv
+   source .venv/bin/activate
+   ```
+
+3. **Install dependencies:**
+   ```sh
+   uv sync
+   ```
+
+### Formatting, Linting and Unit Tests
+- **Formatting (with Black and Ruff) and linting (with Black, Ruff, and MyPy):**
+  ```sh
+  make format
+  make lint
+  ```
+- **Run unit tests:**
+  ```sh
+  make test
+  ```
+
+### Running Locally (without Docker)
+
+Start the API server with Uvicorn:
+
+```sh
+uvicorn app.main:app --reload --port 7860
+```
+
+### Running Locally (with Docker)
+Build and start the API server with Docker:
+
+```sh
+docker build -t embedding-api .
+docker run -p 7860:7860 embedding-api
+```
+### Test the endpoint
+Test the endpoint with either:
+```sh
+curl -X 'POST' \
+  'http://127.0.0.1:7860/embed' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "texts": [
+    "query: what is the capital of France?",
+    "passage: Paris is the capital of France."
+  ]
+}'
+```
+Or through the Swagger UI.
+
+
+
+## Usage
+
+### Embed Endpoint
+
+- **POST** `/embed`
+- **Request Body:**
+  ```json
+  {
+    "texts": [
+      "query: what is the capital of France?",
+      "passage: Paris is the capital of France."
+    ]
+  }
+  ```
+- **Response:**
+  ```json
+  {
+    "embeddings": [[...], [...]]
+  }
+  ```
+
+### Health Endpoint
+
+- **GET** `/health`
+- **Response:**
+  ```json
+  {
+    "status": "ok"
+  }
+  ```
+
+## Project Structure
+
+```
+app/
+  main.py            # FastAPI app
+  embeddings.py      # Embedding logic
+  models.py          # Request/response models
+  logger.py          # Logging setup
+tests/
+  test_api.py        # API tests
+  test_embeddings.py # Embedding tests
+```
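The `curl` call in the README can be mirrored from Python. The sketch below is not part of this commit: it assumes the server is running locally on port 7860 as in the README, and the helper names `build_embed_request` and `fetch_embeddings` are made up for illustration. It uses only the standard library.

```python
import json
import urllib.request

# Assumed local address, taken from the README's curl example.
EMBED_URL = "http://127.0.0.1:7860/embed"


def build_embed_request(texts: list[str], url: str = EMBED_URL) -> urllib.request.Request:
    """Build a POST request whose JSON body matches the README's request shape."""
    body = json.dumps({"texts": texts}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json", "Accept": "application/json"},
        method="POST",
    )


def fetch_embeddings(texts: list[str], url: str = EMBED_URL) -> list[list[float]]:
    """Send the request and return the `embeddings` field of the JSON response."""
    with urllib.request.urlopen(build_embed_request(texts, url)) as response:
        return json.load(response)["embeddings"]


# Example call (requires a running server, so it is left commented out):
# vectors = fetch_embeddings([
#     "query: what is the capital of France?",
#     "passage: Paris is the capital of France.",
# ])
```

Keeping request construction separate from the network call makes the payload easy to inspect or unit-test without a running server.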
app/__init__.py
ADDED
File without changes
app/embeddings.py
ADDED
@@ -0,0 +1,38 @@
+from transformers import AutoTokenizer, AutoModel
+from torch import Tensor
+from app.logger import logger
+
+model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
+tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
+
+
+def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+    """Average pool the token embeddings."""
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def embed_text(texts: list[str]) -> list[list[float]]:
+    """
+    Generate embeddings for a list of texts.
+
+    The model supports a maximum of 512 tokens per input which typically corresponds to about 2000-2500 characters.
+    To avoid losing important information, we set a limit of 2000 characters per input text.
+    """
+    if not texts:
+        raise ValueError("No input texts provided.")
+    if any(len(text) > 2000 for text in texts):
+        raise ValueError(
+            "One or more input texts exceed the maximum length of 2000 characters."
+        )
+
+    batch_dict = tokenizer(
+        texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
+    )
+    logger.info(
+        f"Tokenized {len(texts)} texts with number of tokens per text: {batch_dict['input_ids'].ne(tokenizer.pad_token_id).sum(dim=1).tolist()}"
+    )
+    outputs = model(**batch_dict)
+    embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
+
+    return embeddings.detach().cpu().tolist()
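To make `average_pool` concrete, here is a framework-free sketch of the same masked mean on nested lists, plus a cosine-similarity helper for comparing the vectors that `embed_text` returns (the code above does not L2-normalize them). The function names `masked_mean` and `cosine_similarity` and the toy numbers are illustrative assumptions, not part of the commit.

```python
import math


def masked_mean(hidden: list[list[list[float]]], mask: list[list[int]]) -> list[list[float]]:
    """Average token vectors per sequence, ignoring positions where the mask is 0."""
    pooled = []
    for tokens, token_mask in zip(hidden, mask):
        dim = len(tokens[0])
        kept = sum(token_mask)  # number of real (non-padding) tokens
        sums = [0.0] * dim
        for vector, keep in zip(tokens, token_mask):
            if keep:
                for i, value in enumerate(vector):
                    sums[i] += value
        pooled.append([s / kept for s in sums])
    return pooled


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# One sequence of three tokens; the last token is padding and must not count.
hidden = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]
mask = [[1, 1, 0]]
pooled = masked_mean(hidden, mask)  # [[2.0, 3.0]]
```

The padding vector `[100.0, 100.0]` has no effect on the result, which is the point of masking before the sum and dividing by the count of real tokens.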
app/logger.py
ADDED
@@ -0,0 +1,36 @@
+import logging
+from logging.config import dictConfig
+
+LOGGING_CONFIG = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "default": {
+            "format": "[%(asctime)s] [%(levelname)s] %(name)s: %(message)s",
+            "datefmt": "%Y-%m-%d %H:%M:%S",
+        },
+        "json": {
+            "format": (
+                '{"time": "%(asctime)s", '
+                '"level": "%(levelname)s", '
+                '"name": "%(name)s", '
+                '"message": "%(message)s"}'
+            ),
+            "datefmt": "%Y-%m-%d %H:%M:%S",
+        },
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "formatter": "default",
+        },
+    },
+    "root": {
+        "level": "INFO",
+        "handlers": ["console"],
+    },
+}
+
+dictConfig(LOGGING_CONFIG)
+
+logger = logging.getLogger("app")
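The `json` formatter in this config is defined but not attached to any handler, and because it builds JSON by string interpolation, a log message containing a double quote would produce invalid JSON. One possible alternative, sketched below under the assumption that structured logs are wanted, is a small formatter subclass that serializes with `json.dumps`. The class name `JsonFormatter` and the demo logger name are hypothetical and not part of the commit.

```python
import io
import json
import logging


class JsonFormatter(logging.Formatter):
    """Hypothetical formatter: serialize each record with json.dumps so quotes are escaped."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record, "%Y-%m-%d %H:%M:%S"),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
        }
        return json.dumps(payload)


# Demo on an isolated logger writing to an in-memory stream.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(JsonFormatter())

demo_logger = logging.getLogger("app.json_demo")
demo_logger.setLevel(logging.INFO)
demo_logger.propagate = False  # keep the demo out of the root handlers
demo_logger.addHandler(handler)

demo_logger.info('Tokenized %d texts, first text: "%s"', 1, "query: hi")
parsed = json.loads(stream.getvalue())  # parses cleanly despite the embedded quotes
```

In a `dictConfig` dictionary, a custom formatter class like this can be referenced through the special `"()"` key of a formatter entry.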
app/main.py
ADDED
@@ -0,0 +1,56 @@
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import FileResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+from app.models import EmbedRequest, EmbedResponse
+from app.embeddings import embed_text
+from logging import getLogger
+from pathlib import Path
+
+logger = getLogger(__name__)
+
+app = FastAPI(
+    title="Embedding API",
+    description="A simple API to generate text embeddings using Microsoft's `multilingual-e5-large` model.",
+    version="1.0.0",
+)
+
+# Mount the frontend directory as static files under /static
+FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend"
+if FRONTEND_DIR.exists():
+    app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
+
+# Allow simple cross-origin requests from local development (adjust in production)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.post("/embed", response_model=EmbedResponse)
+async def embed(request: EmbedRequest) -> dict[str, list[list[float]]]:
+    """Generate embeddings for a list of texts."""
+    try:
+        vectors = embed_text(request.texts)
+        return {"embeddings": vectors}
+    except Exception as e:
+        logger.exception("Error generating embeddings")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check() -> dict[str, str]:
+    """Health check endpoint."""
+    return {"status": "ok"}
+
+
+@app.get("/", response_model=None)
+async def root() -> FileResponse | dict[str, str]:
+    """Serve the frontend `index.html` if present, otherwise return small JSON status."""
+    index_file = FRONTEND_DIR / "index.html"
+    if index_file.exists():
+        return FileResponse(str(index_file))
+    return {"status": "ok", "message": "Frontend not found"}
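The `/embed` handler above turns every exception into a 500 response, including the `ValueError` that `embed_text` raises for an empty list, which is a client mistake rather than a server failure. A possible refinement, shown here as a framework-free sketch, is to classify the exception before building the response. The helper name `status_for_exception` and the choice of 422 (the status FastAPI itself uses for request validation errors) are assumptions, not part of the commit.

```python
def status_for_exception(exc: Exception) -> tuple[int, str]:
    """Map an exception to an (HTTP status, detail) pair."""
    if isinstance(exc, ValueError):
        # Input problems, e.g. "No input texts provided."
        return 422, str(exc)
    # Anything else is treated as a server-side failure; internals stay hidden.
    return 500, "Internal server error"


status, detail = status_for_exception(ValueError("No input texts provided."))
```

Inside the handler's `except` block, the pair could be passed on as `HTTPException(status_code=status, detail=detail)`.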
app/models.py
ADDED
@@ -0,0 +1,42 @@
+from pydantic import BaseModel, Field, field_validator, StringConstraints
+from typing import Annotated
+
+PREFIX_ACCEPTED = ["query: ", "passage: "]
+
+ShortText = Annotated[str, StringConstraints(max_length=2000)]
+
+
+class EmbedRequest(BaseModel):
+    """
+    Request model for texts to be embedded.
+    Each text must start with an accepted prefix and be ≤ 2000 characters.
+    The texts need to start with either "query: " or "passage: ".
+    """
+
+    texts: list[ShortText] = Field(
+        ...,
+        json_schema_extra={
+            "example": [
+                "query: what is the capital of France?",
+                "passage: Paris is the capital of France.",
+            ]
+        },
+        description="List of texts to be embedded (≤ 2000 characters each) and must start with 'query: ' or 'passage: '.",
+    )
+
+    @field_validator("texts")
+    @classmethod
+    def check_prefixes(cls, texts: list[str]) -> list[str]:
+        for t in texts:
+            if not any(t.startswith(prefix) for prefix in PREFIX_ACCEPTED):
+                raise ValueError(f"Each text must start with one of {PREFIX_ACCEPTED}")
+        return texts
+
+
+class EmbedResponse(BaseModel):
+    """Response model containing embeddings."""
+
+    embeddings: list[list[float]] = Field(
+        ...,
+        description="List of embedding vectors corresponding to the input texts. Each embedding is a list of floats with length 1024.",
+    )
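The two rules that `EmbedRequest` enforces (a 2000-character cap and a required prefix) can be restated in plain Python, which may help when checking inputs on the client side before sending them. This is a minimal sketch, not the commit's validation code: the function name `validation_errors` and the error message wording are illustrative assumptions.

```python
PREFIX_ACCEPTED = ("query: ", "passage: ")
MAX_CHARS = 2000


def validation_errors(texts: list[str]) -> list[str]:
    """Return one message per rule violation; an empty list means the input passes."""
    errors = []
    for index, text in enumerate(texts):
        if len(text) > MAX_CHARS:
            errors.append(f"texts[{index}]: longer than {MAX_CHARS} characters")
        # str.startswith accepts a tuple of candidate prefixes.
        if not text.startswith(PREFIX_ACCEPTED):
            errors.append(f"texts[{index}]: must start with one of {list(PREFIX_ACCEPTED)}")
    return errors


ok = validation_errors(["query: what is the capital of France?"])
bad = validation_errors(["what is the capital of France?"])
```

An empty list passes these checks, just as it passes the model above, which sets no minimum length. It is `embed_text` that later rejects it.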
frontend/index.html
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Text2Vector | Text to Vector Conversion</title>
|
| 7 |
+
<link rel="icon" type="image/x-icon" href="/static/favicon.ico">
|
| 8 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 9 |
+
<script src="https://unpkg.com/feather-icons"></script>
|
| 10 |
+
<script src="https://cdn.jsdelivr.net/npm/feather-icons/dist/feather.min.js"></script>
|
| 11 |
+
<style>
|
| 12 |
+
.text-input-group:hover .remove-text-btn {
|
| 13 |
+
display: block !important;
|
| 14 |
+
}
|
| 15 |
+
.gradient-bg {
|
| 16 |
+
background: linear-gradient(135deg, #6e8efb 0%, #a777e3 100%);
|
| 17 |
+
}
|
| 18 |
+
.embed-card {
|
| 19 |
+
backdrop-filter: blur(10px);
|
| 20 |
+
background: rgba(255, 255, 255, 0.1);
|
| 21 |
+
border: 1px solid rgba(255, 255, 255, 0.2);
|
| 22 |
+
}
|
| 23 |
+
.text-area {
|
| 24 |
+
min-height: 150px;
|
| 25 |
+
}
|
| 26 |
+
.vector-display {
|
| 27 |
+
font-family: monospace;
|
| 28 |
+
white-space: pre-wrap;
|
| 29 |
+
overflow-x: auto;
|
| 30 |
+
}
|
| 31 |
+
#vanta-bg {
|
| 32 |
+
position: absolute;
|
| 33 |
+
top: 0;
|
| 34 |
+
left: 0;
|
| 35 |
+
width: 100%;
|
| 36 |
+
height: 100%;
|
| 37 |
+
z-index: -1;
|
| 38 |
+
}
|
| 39 |
+
</style>
|
| 40 |
+
</head>
|
| 41 |
+
<body class="min-h-screen text-gray-100">
|
| 42 |
+
<div id="vanta-bg"></div>
|
| 43 |
+
<div class="container mx-auto px-4 py-12">
|
| 44 |
+
<!-- Header -->
|
| 45 |
+
<header class="text-center mb-12">
|
| 46 |
+
<h1 class="text-4xl md:text-5xl font-bold mb-4">Text2Vector ⚡</h1>
|
| 47 |
+
<p class="text-xl opacity-80">Transform your text into powerful vector embeddings</p>
|
| 48 |
+
</header>
|
| 49 |
+
|
| 50 |
+
<!-- Main Content -->
|
| 51 |
+
<main class="max-w-4xl mx-auto">
|
| 52 |
+
<div class="grid md:grid-cols-2 gap-8">
|
| 53 |
+
<!-- Input Section -->
|
| 54 |
+
<div class="embed-card rounded-xl p-6 shadow-lg">
|
| 55 |
+
<div class="flex items-center justify-between mb-4">
|
| 56 |
+
<div class="flex items-center">
|
| 57 |
+
<i data-feather="edit-3" class="mr-2"></i>
|
| 58 |
+
<h2 class="text-xl font-semibold">Input Texts</h2>
|
| 59 |
+
</div>
|
| 60 |
+
<button id="add-text-btn" class="px-3 py-1 gradient-bg hover:opacity-90 rounded-lg transition flex items-center text-sm">
|
| 61 |
+
<i data-feather="plus" class="mr-1"></i>
|
| 62 |
+
Add Field
|
| 63 |
+
</button>
|
| 64 |
+
</div>
|
| 65 |
+
<div id="text-inputs-container">
|
| 66 |
+
<div class="text-input-group mb-3 relative">
|
| 67 |
+
<textarea class="w-full text-area bg-gray-800 bg-opacity-50 rounded-lg p-4 text-white border border-gray-600 focus:border-purple-400 focus:ring-1 focus:ring-purple-400 transition" placeholder="Enter your text here..."></textarea>
|
| 68 |
+
<button class="remove-text-btn absolute top-1 right-1 p-1 bg-gray-700 hover:bg-gray-600 rounded-full transition" style="display: none;">
|
| 69 |
+
<i data-feather="x" class="w-4 h-4"></i>
|
| 70 |
+
</button>
|
| 71 |
+
</div>
|
| 72 |
+
</div>
|
| 73 |
+
<div class="flex justify-between mt-4">
|
| 74 |
+
<button id="clear-btn" class="px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition flex items-center">
|
| 75 |
+
<i data-feather="trash-2" class="mr-2"></i>
|
| 76 |
+
Clear All
|
| 77 |
+
</button>
|
| 78 |
+
<button id="generate-btn" class="px-6 py-2 gradient-bg hover:opacity-90 rounded-lg transition flex items-center">
|
| 79 |
+
<i data-feather="zap" class="mr-2"></i>
|
| 80 |
+
Generate Embeddings
|
| 81 |
+
</button>
|
| 82 |
+
</div>
|
| 83 |
+
</div>
|
| 84 |
+
<!-- Output Section -->
|
| 85 |
+
<div class="embed-card rounded-xl p-6 shadow-lg">
|
| 86 |
+
<div class="flex items-center justify-between mb-4">
|
| 87 |
+
<div class="flex items-center">
|
| 88 |
+
<i data-feather="list" class="mr-2"></i>
|
| 89 |
+
<h2 class="text-xl font-semibold">Vector Embeddings</h2>
|
| 90 |
+
</div>
|
| 91 |
+
<button id="copy-btn" class="px-3 py-1 bg-gray-700 hover:bg-gray-600 rounded-lg transition flex items-center text-sm" disabled>
|
| 92 |
+
<i data-feather="copy" class="mr-1"></i>
|
| 93 |
+
Copy
|
| 94 |
+
</button>
|
| 95 |
+
</div>
|
| 96 |
+
<div id="output-container" class="vector-display bg-gray-800 bg-opacity-50 rounded-lg p-4 h-64 overflow-auto hidden">
|
| 97 |
+
<pre id="output-vector" class="text-sm"></pre>
|
| 98 |
+
</div>
|
| 99 |
+
<div id="placeholder" class="bg-gray-800 bg-opacity-30 rounded-lg p-8 text-center h-64 flex items-center justify-center">
|
| 100 |
+
<div class="opacity-60">
|
| 101 |
+
<i data-feather="wind" class="w-12 h-12 mx-auto mb-4"></i>
|
| 102 |
+
<p>Your embeddings will appear here</p>
|
| 103 |
+
</div>
|
| 104 |
+
</div>
|
| 105 |
+
<button id="download-btn" class="w-full mt-4 px-4 py-2 gradient-bg hover:opacity-90 rounded-lg transition flex items-center justify-center hidden">
|
| 106 |
+
<i data-feather="download" class="mr-2"></i>
|
| 107 |
+
Download as JSON
|
| 108 |
+
</button>
|
| 109 |
+
</div>
|
| 110 |
+
</div>
|
| 111 |
+
|
| 112 |
+
<!-- Info Section -->
|
| 113 |
+
<div class="embed-card rounded-xl p-6 mt-8 shadow-lg">
|
| 114 |
+
<div class="flex items-center mb-4">
|
| 115 |
+
<i data-feather="info" class="mr-2"></i>
|
| 116 |
+
<h2 class="text-xl font-semibold">About Text2Vector</h2>
|
| 117 |
+
</div>
|
| 118 |
+
<p class="mb-4">Text2Vector transforms your text into high-dimensional vector representations using Microsoft's multilingual-e5-large model, capturing semantic meaning across multiple languages.</p>
|
| 119 |
+
<div class="grid md:grid-cols-3 gap-4">
|
| 120 |
+
<div class="bg-gray-800 bg-opacity-30 p-4 rounded-lg">
|
| 121 |
+
<div class="flex items-center mb-2">
|
| 122 |
+
<i data-feather="hash" class="mr-2 text-purple-300"></i>
|
| 123 |
+
<h3 class="font-medium">Dimensionality</h3>
|
| 124 |
+
</div>
|
| 125 |
+
<p class="text-sm opacity-80">1024-dimensional vectors</p>
|
| 126 |
+
</div>
|
| 127 |
+
<div class="bg-gray-800 bg-opacity-30 p-4 rounded-lg">
|
| 128 |
+
<div class="flex items-center mb-2">
|
| 129 |
+
<i data-feather="cpu" class="mr-2 text-purple-300"></i>
|
| 130 |
+
<h3 class="font-medium">Model</h3>
|
| 131 |
+
</div>
|
| 132 |
+
<p class="text-sm opacity-80">Microsoft's multilingual-e5-large</p>
|
| 133 |
+
</div>
|
| 134 |
+
<div class="bg-gray-800 bg-opacity-30 p-4 rounded-lg">
|
| 135 |
+
<div class="flex items-center mb-2">
|
| 136 |
+
<i data-feather="code" class="mr-2 text-purple-300"></i>
|
| 137 |
+
<h3 class="font-medium">API</h3>
|
| 138 |
+
</div>
|
| 139 |
+
<p class="text-sm opacity-80">Simple REST integration</p>
|
| 140 |
+
</div>
|
| 141 |
+
</div>
|
| 142 |
+
</div>
|
| 143 |
+
</main>
|
| 144 |
+
</div>
|
| 145 |
+
<!-- Footer -->
|
| 146 |
+
<footer class="text-center py-8 opacity-70 text-sm">
|
| 147 |
+
<p>© 2023 Text2Vector ⚡ | Powered by multilingual-e5-large embeddings</p>
|
| 148 |
+
<p class="mt-2" id="api-status">API Status: <span class="text-red-500">Checking...</span></p>
|
| 149 |
+
</footer>
|
| 150 |
+
<!-- Scripts -->
|
| 151 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r121/three.min.js"></script>
|
| 152 |
+
<script src="https://cdn.jsdelivr.net/npm/vanta@latest/dist/vanta.globe.min.js"></script>
|
| 153 |
+
<script>
|
| 154 |
+
// Initialize Vanta.js background
|
| 155 |
+
VANTA.GLOBE({
|
| 156 |
+
el: "#vanta-bg",
|
| 157 |
+
mouseControls: true,
|
| 158 |
+
touchControls: true,
|
| 159 |
+
gyroControls: false,
|
| 160 |
+
minHeight: 200.00,
|
| 161 |
+
minWidth: 200.00,
|
| 162 |
+
scale: 1.00,
|
| 163 |
+
scaleMobile: 1.00,
|
| 164 |
+
color: 0x6e8efb,
|
| 165 |
+
backgroundColor: 0x0,
|
| 166 |
+
size: 0.8
|
| 167 |
+
});
|
| 168 |
+
|
| 169 |
+
// Initialize Feather Icons
|
| 170 |
+
feather.replace();
|
| 171 |
+
// DOM Elements
|
| 172 |
+
const textInputsContainer = document.getElementById('text-inputs-container');
|
| 173 |
+
const generateBtn = document.getElementById('generate-btn');
|
| 174 |
+
const clearBtn = document.getElementById('clear-btn');
|
| 175 |
+
const copyBtn = document.getElementById('copy-btn');
|
| 176 |
+
const downloadBtn = document.getElementById('download-btn');
|
| 177 |
+
const outputContainer = document.getElementById('output-container');
|
| 178 |
+
const outputVector = document.getElementById('output-vector');
|
| 179 |
+
const placeholder = document.getElementById('placeholder');
|
| 180 |
+
const addTextBtn = document.getElementById('add-text-btn');
|
| 181 |
+
|
| 182 |
+
// Add new text input field
|
| 183 |
+
addTextBtn.addEventListener('click', () => {
|
| 184 |
+
const newInputGroup = document.createElement('div');
|
| 185 |
+
newInputGroup.className = 'text-input-group mb-3 relative';
|
| 186 |
+
newInputGroup.innerHTML = `
|
| 187 |
+
<textarea class="w-full text-area bg-gray-800 bg-opacity-50 rounded-lg p-4 text-white border border-gray-600 focus:border-purple-400 focus:ring-1 focus:ring-purple-400 transition" placeholder="Enter your text here..."></textarea>
|
| 188 |
+
<button class="remove-text-btn absolute top-1 right-1 p-1 bg-gray-700 hover:bg-gray-600 rounded-full transition">
|
| 189 |
+
<i data-feather="x" class="w-4 h-4"></i>
|
| 190 |
+
</button>
|
| 191 |
+
`;
|
| 192 |
+
textInputsContainer.appendChild(newInputGroup);
|
| 193 |
+
feather.replace();
|
| 194 |
+
setupRemoveButtons();
|
| 195 |
+
});
|
| 196 |
+
|
| 197 |
+
// Setup remove buttons for all input fields
|
| 198 |
+
function setupRemoveButtons() {
|
| 199 |
+
document.querySelectorAll('.remove-text-btn').forEach(btn => {
|
| 200 |
+
btn.addEventListener('click', (e) => {
|
| 201 |
+
if (document.querySelectorAll('.text-input-group').length > 1) {
|
| 202 |
+
e.target.closest('.text-input-group').remove();
|
| 203 |
+
}
|
| 204 |
+
});
|
| 205 |
+
});
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// Show remove buttons when hovering over input groups
|
| 209 |
+
document.addEventListener('mouseover', (e) => {
|
| 210 |
+
if (e.target.closest('.text-input-group')) {
|
| 211 |
+
const group = e.target.closest('.text-input-group');
|
| 212 |
+
if (document.querySelectorAll('.text-input-group').length > 1) {
|
| 213 |
+
group.querySelector('.remove-text-btn').style.display = 'block';
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
});
|
| 217 |
+
|
| 218 |
+
document.addEventListener('mouseout', (e) => {
|
| 219 |
+
if (e.target.closest('.text-input-group')) {
|
| 220 |
+
const group = e.target.closest('.text-input-group');
|
| 221 |
+
group.querySelector('.remove-text-btn').style.display = 'none';
|
| 222 |
+
}
|
| 223 |
+
});
|
| 224 |
+
|
| 225 |
+
setupRemoveButtons();
|
| 226 |
+
// API Configuration
|
| 227 |
+
const API_URL = '/embed';
|
| 228 |
+
// Check API health
|
| 229 |
+
async function checkApiHealth() {
|
| 230 |
+
try {
|
| 231 |
+
const response = await fetch('/health');
|
| 232 |
+
if (response.ok) {
|
| 233 |
+
document.getElementById('api-status').innerHTML =
|
| 234 |
+
'API Status: <span class="text-green-500">Online</span>';
|
| 235 |
+
} else {
|
| 236 |
+
throw new Error('API not responding');
|
| 237 |
+
}
|
| 238 |
+
} catch (error) {
|
| 239 |
+
document.getElementById('api-status').innerHTML =
|
| 240 |
+
'API Status: <span class="text-red-500">Offline</span>';
|
| 241 |
+
console.error('API health check failed:', error);
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// Initial health check
|
| 246 |
+
checkApiHealth();
|
| 247 |
+
setInterval(checkApiHealth, 30000); // Check every 30 seconds
|
| 248 |
+
|
| 249 |
+
// Event Listeners
|
| 250 |
+
generateBtn.addEventListener('click', async () => {
|
| 251 |
+
const textInputs = Array.from(document.querySelectorAll('.text-input-group textarea'));
|
| 252 |
+
const texts = textInputs
|
| 253 |
+
.map(input => input.value.trim())
|
| 254 |
+
.filter(text => text.length > 0);
|
| 255 |
+
|
| 256 |
+
if (texts.length === 0) {
|
| 257 |
+
alert('Please enter at least one text input');
|
| 258 |
+
return;
|
| 259 |
+
}
|
| 260 |
+
// Show loading state
|
| 261 |
+
generateBtn.disabled = true;
|
| 262 |
+
generateBtn.innerHTML = '<i data-feather="loader" class="animate-spin mr-2"></i> Processing...';
|
| 263 |
+
feather.replace();
|
| 264 |
+
// Make API call to backend
|
| 265 |
+
try {
|
| 266 |
+
const response = await fetch('/embed', {
|
| 267 |
+
method: 'POST',
|
| 268 |
+
headers: {
|
| 269 |
+
'Content-Type': 'application/json',
|
| 270 |
+
},
|
+                    body: JSON.stringify({
+                        texts: texts
+                    })
+                });
+                if (!response.ok) {
+                    throw new Error(`API Error: ${response.status}`);
+                }
+                const data = await response.json();
+                const embeddings = texts.map((text, index) => ({
+                    text: text,
+                    vector: data.embeddings[index],
+                    model: "multilingual-e5-large",
+                    timestamp: new Date().toISOString()
+                }));
+
+                // Display the embeddings
+                outputVector.textContent = JSON.stringify(embeddings.length === 1 ? embeddings[0] : embeddings, null, 2);
+                placeholder.classList.add('hidden');
+                outputContainer.classList.remove('hidden');
+                copyBtn.disabled = false;
+                downloadBtn.classList.remove('hidden');
+
+            } catch (error) {
+                alert(`Failed to generate embeddings: ${error.message}`);
+                console.error(error);
+            }
+
+            // Reset button
+            generateBtn.disabled = false;
+            generateBtn.innerHTML = '<i data-feather="zap" class="mr-2"></i> Generate Embeddings';
+            feather.replace();
+        });
+
+        clearBtn.addEventListener('click', () => {
+            document.querySelectorAll('.text-input-group textarea').forEach(input => {
+                input.value = '';
+            });
+            outputVector.textContent = '';
+            outputContainer.classList.add('hidden');
+            placeholder.classList.remove('hidden');
+            copyBtn.disabled = true;
+            downloadBtn.classList.add('hidden');
+        });
+
+        copyBtn.addEventListener('click', () => {
+            navigator.clipboard.writeText(outputVector.textContent)
+                .then(() => {
+                    copyBtn.innerHTML = '<i data-feather="check" class="mr-1"></i> Copied!';
+                    feather.replace();
+                    setTimeout(() => {
+                        copyBtn.innerHTML = '<i data-feather="copy" class="mr-1"></i> Copy';
+                        feather.replace();
+                    }, 2000);
+                });
+        });
+
+        downloadBtn.addEventListener('click', () => {
+            const blob = new Blob([outputVector.textContent], { type: 'application/json' });
+            const url = URL.createObjectURL(blob);
+            const a = document.createElement('a');
+            a.href = url;
+            a.download = `embedding-${new Date().getTime()}.json`;
+            document.body.appendChild(a);
+            a.click();
+            document.body.removeChild(a);
+            URL.revokeObjectURL(url);
+        });
+    </script>
+</body>
+</html>
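The click handler above pairs each input text with the vector at the same index in the API response. For scripting against the same `/embed` endpoint outside the browser, that pairing could be sketched in Python as follows. This is a minimal sketch and not part of the commit: `build_records` is a hypothetical helper, and the explicit length check is an added assumption (the frontend code does not perform one).

```python
from datetime import datetime, timezone
from typing import Any, Dict, List

MODEL_NAME = "multilingual-e5-large"  # label the frontend attaches to every record


def build_records(texts: List[str], embeddings: List[List[float]]) -> List[Dict[str, Any]]:
    """Pair each text with the vector at the same index (hypothetical helper)."""
    # Assumption: the API returns exactly one vector per input text, in order.
    if len(texts) != len(embeddings):
        raise ValueError(f"expected {len(texts)} vectors, got {len(embeddings)}")
    timestamp = datetime.now(timezone.utc).isoformat()
    return [
        {"text": text, "vector": vector, "model": MODEL_NAME, "timestamp": timestamp}
        for text, vector in zip(texts, embeddings)
    ]
```

Failing early on a length mismatch avoids records whose `vector` is silently missing, which is what an index lookup past the end of the response would otherwise produce.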
pyproject.toml
ADDED
@@ -0,0 +1,43 @@
+[project]
+name = "text2vector"
+version = "1.0.0"
+description = "API to call an embedding model"
+readme = "README.md"
+requires-python = ">=3.12"
+
+dependencies = [
+    "black>=25.9.0",
+    "fastapi>=0.119.0",
+    "httpx>=0.28.1",
+    "mypy>=1.18.2",
+    "pydantic>=2.12.0",
+    "pytest>=8.4.2",
+    "ruff>=0.14.0",
+    "torch>=2.8.0",
+    "transformers>=4.57.0",
+    "uvicorn>=0.37.0",
+]
+
+# https://quantlane.com/blog/type-checking-large-codebase/
+[tool.mypy]
+# Ensure full coverage
+disallow_untyped_calls = false
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_untyped_decorators = false
+check_untyped_defs = true
+
+# Restrict dynamic typing
+disallow_any_generics = false
+disallow_subclassing_any = false
+warn_return_any = false
+
+# Know exactly what you're doing
+warn_redundant_casts = true
+warn_unused_ignores = false
+warn_unused_configs = true
+warn_unreachable = true
+show_error_codes = true
+
+# Explicit is better than implicit
+no_implicit_optional = true
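With `disallow_untyped_defs` and `disallow_incomplete_defs` enabled, mypy reports any function whose signature is missing annotations, and `no_implicit_optional` means a `None` default does not silently make a parameter optional. Below is a minimal sketch of code that satisfies these settings. `truncate` is a hypothetical function used only for illustration and does not appear in the project.

```python
from typing import Optional


def truncate(text: str, limit: Optional[int] = None) -> str:
    """Return `text` cut to `limit` characters, or unchanged when limit is None."""
    # Writing `limit: int = None` would be flagged under no_implicit_optional;
    # the Optional has to be spelled out in the annotation.
    if limit is None:
        return text
    return text[:limit]
```

Dropping the `-> str` return annotation or the `text: str` annotation would make the definition incomplete, which `disallow_incomplete_defs` reports.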
tests/__init__.py
ADDED
File without changes
tests/test_api.py
ADDED
@@ -0,0 +1,26 @@
+from fastapi.testclient import TestClient
+from app.main import app
+
+client = TestClient(app)
+
+
+def test_embed() -> None:
+    """Test the /embed endpoint with valid input."""
+    response = client.post("/embed", json={"texts": ["query: Hello world"]})
+    assert response.status_code == 200  # OK
+    data = response.json()
+    assert "embeddings" in data
+    assert len(data["embeddings"][0]) == 1024
+
+
+def test_embed_no_texts() -> None:
+    """Test the /embed endpoint with no texts provided."""
+    response = client.post("/embed", json={})
+    assert response.status_code == 422  # Unprocessable Entity
+
+
+def test_embed_long_text() -> None:
+    """Test the /embed endpoint with a text longer than 2000 characters."""
+    long_text = "query: " + "a" * 1994  # 2001 characters
+    response = client.post("/embed", json={"texts": [long_text]})
+    assert response.status_code == 422  # Unprocessable Entity
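The last test relies on simple arithmetic: the prefix `query: ` is 7 characters, so appending 1994 characters gives 2001, one over the limit named in the docstring. The check it exercises might look like the sketch below. `MAX_CHARS` and `find_too_long` are hypothetical names, and treating exactly 2000 characters as acceptable is an assumption read off the docstring's "longer than 2000 characters". The actual validation lives in application code that is not shown in this part of the diff.

```python
from typing import List

MAX_CHARS = 2000  # limit named in the test docstring (assumed to be inclusive)


def find_too_long(texts: List[str], limit: int = MAX_CHARS) -> List[int]:
    """Return the indices of texts longer than `limit` characters."""
    return [i for i, text in enumerate(texts) if len(text) > limit]


at_limit = "query: " + "a" * 1993    # 7 + 1993 = 2000 characters
over_limit = "query: " + "a" * 1994  # 7 + 1994 = 2001 characters
print(find_too_long([at_limit, over_limit]))  # [1]
```

Returning the offending indices rather than a bare boolean would let an API layer say which input failed when it answers with a 422.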
tests/test_embeddings.py
ADDED
@@ -0,0 +1,56 @@
+from app.embeddings import average_pool, embed_text
+import torch
+import pytest
+
+
+def test_average_pool_basic() -> None:
+    """Test average pooling produces correct shape and masking."""
+    last_hidden_states = torch.tensor(
+        [
+            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+            [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]],
+        ]
+    )  # shape: (2, 3, 2)
+    attention_mask = torch.tensor(
+        [
+            [1, 1, 0],
+            [1, 0, 0],
+        ]
+    )  # shape: (2, 3)
+
+    result = average_pool(last_hidden_states, attention_mask)
+
+    # Expected averages:
+    # row1: [(1+3)/2, (2+4)/2] = [2,3]
+    # row2: [10, 20]
+    expected = torch.tensor([[2.0, 3.0], [10.0, 20.0]])
+
+    assert torch.allclose(result, expected, atol=1e-6)
+    assert result.shape == (2, 2)
+
+
+def test_embed_text_valid() -> None:
+    """Test embedding returns correct number of vectors and dimensions."""
+
+    texts = ["query: Hello world", "query: Hej verden"]
+    embeddings = embed_text(texts)
+
+    # Assertions
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == len(texts)
+    assert all(isinstance(vec, list) for vec in embeddings)
+    assert all(isinstance(x, float) for x in embeddings[0])
+    assert len(embeddings[0]) == 1024
+
+
+def test_embed_text_empty_list() -> None:
+    """Should raise ValueError if no input texts."""
+    with pytest.raises(ValueError, match="No input texts provided"):
+        embed_text([])
+
+
+def test_embed_text_too_long() -> None:
+    """Should raise ValueError for inputs exceeding 2000 characters."""
+    too_long = ["query: " + "a" * 1994]  # 2001 characters
+    with pytest.raises(ValueError, match="exceed the maximum length"):
+        embed_text(too_long)
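The expected values in `test_average_pool_basic` follow from masked mean pooling: positions where the attention mask is 0 are ignored, and the remaining token vectors are averaged. The sketch below reproduces that arithmetic with plain Python lists instead of tensors. `masked_mean` is a hypothetical name, and this is not the `average_pool` from `app/embeddings.py`, which is not shown in this part of the diff.

```python
from typing import List


def masked_mean(hidden: List[List[List[float]]], mask: List[List[int]]) -> List[List[float]]:
    """Average token vectors per sequence, skipping positions whose mask is 0."""
    pooled = []
    for tokens, flags in zip(hidden, mask):
        # Keep only the token vectors whose mask entry is non-zero.
        kept = [vec for vec, flag in zip(tokens, flags) if flag]
        if not kept:
            raise ValueError("sequence has no unmasked tokens")
        dim = len(kept[0])
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled


hidden = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]],
]
mask = [[1, 1, 0], [1, 0, 0]]
print(masked_mean(hidden, mask))  # [[2.0, 3.0], [10.0, 20.0]]
```

The first sequence averages its first two tokens and the second keeps only its first token, matching the `expected` tensor in the test.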
uv.lock
ADDED
The diff for this file is too large to render.