Spaces:
Sleeping
Sleeping
Commit
·
8f842e4
0
Parent(s):
Init FastAPI python project
Browse files- .github/workflows/deploy.yml +51 -0
- .gitignore +11 -0
- .python-version +1 -0
- Dockerfile +15 -0
- Makefile +15 -0
- README.md +141 -0
- app/__init__.py +0 -0
- app/embeddings.py +34 -0
- app/logger.py +5 -0
- app/main.py +28 -0
- app/models.py +42 -0
- pyproject.toml +19 -0
- tests/__init__.py +0 -0
- tests/test_api.py +26 -0
- tests/test_embeddings.py +56 -0
- uv.lock +0 -0
.github/workflows/deploy.yml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: test-and-deploy
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
workflow_dispatch:
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
build-and-test:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
|
| 13 |
+
steps:
|
| 14 |
+
- name: Checkout repository
|
| 15 |
+
uses: actions/checkout@v5
|
| 16 |
+
|
| 17 |
+
- name: Install uv
|
| 18 |
+
uses: astral-sh/setup-uv@v6
|
| 19 |
+
with:
|
| 20 |
+
version: "0.9.2"
|
| 21 |
+
|
| 22 |
+
- name: Set up Python
|
| 23 |
+
run: uv python install
|
| 24 |
+
|
| 25 |
+
- name: Install dependencies
|
| 26 |
+
run: uv sync --locked --all-extras --dev
|
| 27 |
+
|
| 28 |
+
- name: Run linting
|
| 29 |
+
run: make lint
|
| 30 |
+
|
| 31 |
+
- name: Run tests
|
| 32 |
+
run: make test
|
| 33 |
+
|
| 34 |
+
deploy-to-hf:
|
| 35 |
+
needs: build-and-test
|
| 36 |
+
runs-on: ubuntu-latest
|
| 37 |
+
|
| 38 |
+
steps:
|
| 39 |
+
- name: Checkout repository
|
| 40 |
+
uses: actions/checkout@v5
|
| 41 |
+
with:
|
| 42 |
+
fetch-depth: 0
|
| 43 |
+
lfs: true
|
| 44 |
+
|
| 45 |
+
- name: Deploy to Hugging Face Space
|
| 46 |
+
env:
|
| 47 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 48 |
+
HF_USERNAME: ${{ secrets.HF_USERNAME }}
|
| 49 |
+
HF_SPACE: ${{ secrets.HF_SPACE }}
|
| 50 |
+
run: |
|
| 51 |
+
git push --force https://$HF_USERNAME:$HF_TOKEN@huggingface.co/spaces/$HF_USERNAME/$HF_SPACE main
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
.DS_Store
|
| 9 |
+
|
| 10 |
+
# Virtual environments
|
| 11 |
+
.venv
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
Dockerfile
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
+
|
| 3 |
+
RUN useradd -m -u 1000 user
|
| 4 |
+
USER user
|
| 5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
COPY --chown=user pyproject.toml ./
|
| 10 |
+
RUN pip install --no-cache-dir .
|
| 11 |
+
|
| 12 |
+
COPY --chown=user app ./app
|
| 13 |
+
|
| 14 |
+
# Start the app with Uvicorn
|
| 15 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
Makefile
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
APP_DIR := $(CURDIR)/app
TESTS_DIR := $(CURDIR)/tests

# Auto-format sources, then apply ruff's auto-fixes.
format:
	uv run black $(APP_DIR) $(TESTS_DIR)/*.py
	uv run ruff check $(APP_DIR) $(TESTS_DIR) --fix

# Fail on formatting drift, lint findings, or type errors (used by CI).
lint:
	uv run black --check $(APP_DIR) $(TESTS_DIR)/*.py
	uv run ruff check $(APP_DIR) $(TESTS_DIR)
	uv run mypy $(APP_DIR) $(TESTS_DIR)

# Run the unit-test suite.
# NOTE: was ${TESTS_DIR}; use $(...) consistently with the rest of the file.
test:
	uv run pytest $(TESTS_DIR)
|
| 15 |
+
|
README.md
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Text Embedding
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Embedding API
|
| 12 |
+
|
| 13 |
+
API to call an embedding model ([intfloat/multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large)) for generating multilingual text embeddings.<br>
|
| 14 |
+
The embedding model takes a text string and converts it into a 1024-dimension vector.<br>
|
| 15 |
+
Using a `POST` request to the `/embed` endpoint with a list of texts, the API returns their corresponding embeddings.<br>
|
| 16 |
+
A maximum of 2000 characters per text is enforced to avoid truncation, and thereby loss of information, by the tokenizer.<br>
|
| 17 |
+
Each text must start with either "query: " or "passage: ".<br>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
The API is deployed at a Hugging Face Docker space where the Swagger UI can be accessed at:<br>
|
| 21 |
+
[https://emilbm-text-embedding.hf.space/docs](https://emilbm-text-embedding.hf.space/docs)
|
| 22 |
+
|
| 23 |
+
## Features
|
| 24 |
+
|
| 25 |
+
- FastAPI-based REST API
|
| 26 |
+
- `/embed` endpoint for generating embeddings from a list of texts
|
| 27 |
+
- `/health` endpoint for checking the API status
|
| 28 |
+
- Uses HuggingFace Transformers and PyTorch
|
| 29 |
+
- Includes linting and unit tests
|
| 30 |
+
- Dockerfile for containerization
|
| 31 |
+
- CI/CD with GitHub Actions to build, lint, test, and deploy to Hugging Face
|
| 32 |
+
|
| 33 |
+
## Local Development
|
| 34 |
+
### Requirements
|
| 35 |
+
|
| 36 |
+
- Python 3.12+
|
| 37 |
+
- [UV](https://docs.astral.sh/uv/)
|
| 38 |
+
- (Optional) Docker
|
| 39 |
+
### Installation
|
| 40 |
+
|
| 41 |
+
1. **Clone the repository:**
|
| 42 |
+
```sh
|
| 43 |
+
git clone <your-repo-url>
|
| 44 |
+
cd embedding-api
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
2. **Create a virtual environment and activate it:**
|
| 48 |
+
```sh
|
| 49 |
+
uv venv
|
| 50 |
+
source .venv/bin/activate
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
3. **Install dependencies:**
|
| 54 |
+
```sh
|
| 55 |
+
uv sync
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Formatting, Linting and Unit Tests
|
| 59 |
+
- **Formatting (with Black and Ruff) and linting (with Black, Ruff, and MyPy):**
|
| 60 |
+
```sh
|
| 61 |
+
make format
|
| 62 |
+
make lint
|
| 63 |
+
```
|
| 64 |
+
- **Run unit tests:**
|
| 65 |
+
```sh
|
| 66 |
+
make test
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Running Locally (without Docker)
|
| 70 |
+
|
| 71 |
+
Start the API server with Uvicorn:
|
| 72 |
+
|
| 73 |
+
```sh
|
| 74 |
+
uvicorn app.main:app --reload --port 7860
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Running Locally (with Docker)
|
| 78 |
+
Build and start the API server with Docker:
|
| 79 |
+
|
| 80 |
+
```sh
|
| 81 |
+
docker build -t embedding-api .
|
| 82 |
+
docker run -p 7860:7860 embedding-api
|
| 83 |
+
```
|
| 84 |
+
### Test the endpoint
|
| 85 |
+
Test the endpoint with either:
|
| 86 |
+
```sh
|
| 87 |
+
curl -X 'POST' \
|
| 88 |
+
'http://127.0.0.1:7860/embed' \
|
| 89 |
+
-H 'accept: application/json' \
|
| 90 |
+
-H 'Content-Type: application/json' \
|
| 91 |
+
-d '{
|
| 92 |
+
"texts": [
|
| 93 |
+
"query: what is the capital of France?",
|
| 94 |
+
"passage: Paris is the capital of France."
|
| 95 |
+
]
|
| 96 |
+
}'
|
| 97 |
+
```
|
| 98 |
+
Or through the Swagger UI.
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
## Usage
|
| 103 |
+
|
| 104 |
+
### Embed Endpoint
|
| 105 |
+
|
| 106 |
+
- **POST** `/embed`
|
| 107 |
+
- **Request Body:**
|
| 108 |
+
```json
|
| 109 |
+
{
|
| 110 |
+
"texts": ["Hello world", "Hej verden"]
|
| 111 |
+
}
|
| 112 |
+
```
|
| 113 |
+
- **Response:**
|
| 114 |
+
```json
|
| 115 |
+
{
|
| 116 |
+
"embeddings": [[...], [...]]
|
| 117 |
+
}
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### Health Endpoint
|
| 121 |
+
|
| 122 |
+
- **GET** `/health`
|
| 123 |
+
- **Response:**
|
| 124 |
+
```json
|
| 125 |
+
{
|
| 126 |
+
"status": "ok"
|
| 127 |
+
}
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
## Project Structure
|
| 131 |
+
|
| 132 |
+
```
|
| 133 |
+
app/
|
| 134 |
+
main.py # FastAPI app
|
| 135 |
+
embeddings.py # Embedding logic
|
| 136 |
+
models.py # Request/response models
|
| 137 |
+
logger.py # Logging setup
|
| 138 |
+
tests/
|
| 139 |
+
test_api.py # API tests
|
| 140 |
+
test_embeddings.py # Embedding tests
|
| 141 |
+
```
|
app/__init__.py
ADDED
|
File without changes
|
app/embeddings.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModel
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
|
| 4 |
+
model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
|
| 5 |
+
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the sequence axis, ignoring padding.

    Positions where the attention mask is 0 are zeroed before summing, and
    each row is divided by its count of real (unmasked) tokens.
    """
    pad_positions = ~attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(pad_positions, 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1)[..., None]
    return summed / token_counts
| 13 |
+
|
| 14 |
+
def embed_text(texts: list[str]) -> list[list[float]]:
    """
    Generate embeddings for a list of texts.

    The model supports a maximum of 512 tokens per input which typically
    corresponds to about 2000-2500 characters. To avoid losing important
    information, we set a limit of 2000 characters per input text.

    Args:
        texts: Non-empty list of input strings, each <= 2000 characters.

    Returns:
        One embedding (list of floats) per input text, in input order.

    Raises:
        ValueError: If ``texts`` is empty or any text exceeds 2000 characters.
    """
    # Local import: only this function needs the torch namespace itself
    # (the module otherwise imports just `Tensor`).
    import torch

    if not texts:
        raise ValueError("No input texts provided.")
    if any(len(text) > 2000 for text in texts):
        raise ValueError(
            "One or more input texts exceed the maximum length of 2000 characters."
        )

    batch_dict = tokenizer(
        texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
    )
    # Inference only: disable autograd so no graph is built, saving memory
    # and time. This also makes a separate .detach() unnecessary.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict["attention_mask"]
        )

    return embeddings.cpu().tolist()
|
app/logger.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging

# Configure the root logger once at import time; every module-level
# getLogger(__name__) logger in the app inherits this handler and format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
|
app/main.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging

from fastapi import FastAPI, HTTPException

from app.embeddings import embed_text
from app.models import EmbedRequest, EmbedResponse

# ASGI application object; uvicorn serves this as "app.main:app".
app = FastAPI(
    title="Embedding API",
    description="A simple API to generate text embeddings using Microsoft's `multilingual-e5-large` model.",
    version="1.0.0",
)

# Module-scoped logger, configured via app.logger's basicConfig.
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest):
    """Generate embeddings for a list of texts.

    Delegates to ``embed_text`` and returns one 1024-dim vector per input.
    Any failure is logged with its traceback and surfaced as HTTP 500
    carrying the error message as detail.
    """
    try:
        vectors = embed_text(request.texts)
        return {"embeddings": vectors}
    except Exception as e:
        logger.exception("Error generating embeddings")
        # Chain the cause (PEP 3134 / ruff B904) so the original traceback
        # stays attached to the HTTPException.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    # Liveness probe: reaching this handler at all means the app is up.
    payload = {"status": "ok"}
    return payload
|
app/models.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field, field_validator, StringConstraints
|
| 2 |
+
from typing import Annotated
|
| 3 |
+
|
| 4 |
+
PREFIX_ACCEPTED = ["query: ", "passage: "]
|
| 5 |
+
|
| 6 |
+
ShortText = Annotated[str, StringConstraints(max_length=2000)]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EmbedRequest(BaseModel):
    """
    Request model for texts to be embedded.

    Each text must start with an accepted prefix and be ≤ 2000 characters.
    The texts need to start with either "query: " or "passage: ".
    """

    # Length cap (2000 chars) is enforced by the ShortText annotation.
    texts: list[ShortText] = Field(
        ...,
        json_schema_extra={
            "example": [
                "query: what is the capital of France?",
                "passage: Paris is the capital of France.",
            ]
        },
        description="List of texts to be embedded (≤ 2000 characters each) and must start with 'query: ' or 'passage: '.",
    )

    @field_validator("texts")
    @classmethod
    def check_prefixes(cls, texts: list[str]) -> list[str]:
        """Reject any text that lacks one of the accepted prefixes."""
        accepted = tuple(PREFIX_ACCEPTED)
        for text in texts:
            if not text.startswith(accepted):
                raise ValueError(f"Each text must start with one of {PREFIX_ACCEPTED}")
        return texts
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class EmbedResponse(BaseModel):
    """Response model containing embeddings."""

    # One 1024-dim vector per input text, in request order.
    embeddings: list[list[float]] = Field(
        ...,
        description=(
            "List of embedding vectors corresponding to the input texts. "
            "Each embedding is a list of floats with length 1024."
        ),
    )
|
pyproject.toml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "embedding-api"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "API to call an embedding model"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
|
| 8 |
+
dependencies = [
|
| 9 |
+
"black>=25.9.0",
|
| 10 |
+
"fastapi>=0.119.0",
|
| 11 |
+
"httpx>=0.28.1",
|
| 12 |
+
"mypy>=1.18.2",
|
| 13 |
+
"pydantic>=2.12.0",
|
| 14 |
+
"pytest>=8.4.2",
|
| 15 |
+
"ruff>=0.14.0",
|
| 16 |
+
"torch>=2.8.0",
|
| 17 |
+
"transformers>=4.57.0",
|
| 18 |
+
"uvicorn>=0.37.0",
|
| 19 |
+
]
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi.testclient import TestClient
|
| 2 |
+
from app.main import app
|
| 3 |
+
|
| 4 |
+
client = TestClient(app)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_embed():
    """Test the /embed endpoint with valid input."""
    payload = {"texts": ["query: Hello world"]}
    response = client.post("/embed", json=payload)
    assert response.status_code == 200  # OK
    body = response.json()
    assert "embeddings" in body
    # The e5-large model produces 1024-dimensional vectors.
    assert len(body["embeddings"][0]) == 1024
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_embed_no_texts():
    """Test the /embed endpoint with no texts provided."""
    # Missing required "texts" field -> pydantic validation failure.
    response = client.post("/embed", json={})
    assert response.status_code == 422  # Unprocessable Entity
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_embed_long_text():
    """Test the /embed endpoint with a text longer than 2000 characters."""
    # 7-char prefix + 1994 chars = 2001, one past the 2000-char limit.
    oversized = "query: " + "a" * 1994  # 2001 characters
    response = client.post("/embed", json={"texts": [oversized]})
    assert response.status_code == 422  # Unprocessable Entity
|
tests/test_embeddings.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.embeddings import average_pool, embed_text
|
| 2 |
+
import torch
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_average_pool_basic():
    """Test average pooling produces correct shape and masking."""
    # Token embeddings, shape (batch=2, seq=3, dim=2).
    hidden = torch.tensor(
        [
            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
            [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]],
        ]
    )
    # Attention mask, shape (2, 3); zeros mark padding positions.
    mask = torch.tensor(
        [
            [1, 1, 0],
            [1, 0, 0],
        ]
    )

    pooled = average_pool(hidden, mask)

    # Row 1 averages its first two tokens: [(1+3)/2, (2+4)/2] = [2, 3].
    # Row 2 keeps only its first token: [10, 20].
    expected = torch.tensor([[2.0, 3.0], [10.0, 20.0]])

    assert pooled.shape == (2, 2)
    assert torch.allclose(pooled, expected, atol=1e-6)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_embed_text_valid():
    """Test embedding returns correct number of vectors and dimensions."""
    texts = ["query: Hello world", "query: Hej verden"]
    vectors = embed_text(texts)

    # One 1024-dim list-of-floats per input, in order.
    assert isinstance(vectors, list)
    assert len(vectors) == len(texts)
    for vec in vectors:
        assert isinstance(vec, list)
    assert all(isinstance(value, float) for value in vectors[0])
    assert len(vectors[0]) == 1024
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_embed_text_empty_list():
    """Should raise ValueError if no input texts."""
    # An empty batch is rejected before any tokenization happens.
    with pytest.raises(ValueError, match="No input texts provided"):
        embed_text([])
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_embed_text_too_long():
    """Should raise ValueError for inputs exceeding 2000 characters."""
    # 7-char prefix + 1994 'a's = 2001 characters, one past the limit.
    oversized = ["query: " + "a" * 1994]  # 2001 characters
    with pytest.raises(ValueError, match="exceed the maximum length"):
        embed_text(oversized)
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|