Spaces:
Sleeping
Sleeping
Added mypy rules and applied them.
Browse files
Created logger and showcased basic return of info.
- app/embeddings.py +4 -0
- app/logger.py +34 -3
- app/main.py +5 -4
- pyproject.toml +24 -0
- tests/test_api.py +3 -3
- tests/test_embeddings.py +4 -4
app/embeddings.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from transformers import AutoTokenizer, AutoModel
|
| 2 |
from torch import Tensor
|
|
|
|
| 3 |
|
| 4 |
model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
|
| 5 |
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
|
|
@@ -28,6 +29,9 @@ def embed_text(texts: list[str]) -> list[list[float]]:
|
|
| 28 |
batch_dict = tokenizer(
|
| 29 |
texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
|
| 30 |
)
|
|
|
|
|
|
|
|
|
|
| 31 |
outputs = model(**batch_dict)
|
| 32 |
embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
|
| 33 |
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModel
|
| 2 |
from torch import Tensor
|
| 3 |
+
from app.logger import logger
|
| 4 |
|
| 5 |
model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
|
| 6 |
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
|
|
|
|
| 29 |
batch_dict = tokenizer(
|
| 30 |
texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
|
| 31 |
)
|
| 32 |
+
logger.info(
|
| 33 |
+
f"Tokenized {len(texts)} texts with number of tokens per text: {batch_dict['input_ids'].ne(tokenizer.pad_token_id).sum(dim=1).tolist()}"
|
| 34 |
+
)
|
| 35 |
outputs = model(**batch_dict)
|
| 36 |
embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
|
| 37 |
|
app/logger.py
CHANGED
|
@@ -1,5 +1,36 @@
|
|
| 1 |
import logging
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from logging.config import dictConfig
|
| 3 |
|
| 4 |
+
LOGGING_CONFIG = {
|
| 5 |
+
"version": 1,
|
| 6 |
+
"disable_existing_loggers": False,
|
| 7 |
+
"formatters": {
|
| 8 |
+
"default": {
|
| 9 |
+
"format": "[%(asctime)s] [%(levelname)s] %(name)s: %(message)s",
|
| 10 |
+
"datefmt": "%Y-%m-%d %H:%M:%S",
|
| 11 |
+
},
|
| 12 |
+
"json": {
|
| 13 |
+
"format": (
|
| 14 |
+
'{"time": "%(asctime)s", '
|
| 15 |
+
'"level": "%(levelname)s", '
|
| 16 |
+
'"name": "%(name)s", '
|
| 17 |
+
'"message": "%(message)s"}'
|
| 18 |
+
),
|
| 19 |
+
"datefmt": "%Y-%m-%d %H:%M:%S",
|
| 20 |
+
},
|
| 21 |
+
},
|
| 22 |
+
"handlers": {
|
| 23 |
+
"console": {
|
| 24 |
+
"class": "logging.StreamHandler",
|
| 25 |
+
"formatter": "default",
|
| 26 |
+
},
|
| 27 |
+
},
|
| 28 |
+
"root": {
|
| 29 |
+
"level": "INFO",
|
| 30 |
+
"handlers": ["console"],
|
| 31 |
+
},
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
dictConfig(LOGGING_CONFIG)
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger("app")
|
app/main.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from app.models import EmbedRequest, EmbedResponse
|
| 3 |
from app.embeddings import embed_text
|
| 4 |
-
import logging
|
| 5 |
|
| 6 |
app = FastAPI(
|
| 7 |
title="Embedding API",
|
| 8 |
description="A simple API to generate text embeddings using Microsoft's `multilingual-e5-large` model.",
|
| 9 |
version="1.0.0",
|
| 10 |
)
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
@app.post("/embed", response_model=EmbedResponse)
|
| 15 |
-
async def embed(request: EmbedRequest):
|
| 16 |
"""Generate embeddings for a list of texts."""
|
|
|
|
| 17 |
try:
|
| 18 |
vectors = embed_text(request.texts)
|
|
|
|
| 19 |
return {"embeddings": vectors}
|
| 20 |
except Exception as e:
|
| 21 |
logger.exception("Error generating embeddings")
|
|
@@ -23,6 +24,6 @@ async def embed(request: EmbedRequest):
|
|
| 23 |
|
| 24 |
|
| 25 |
@app.get("/health")
|
| 26 |
-
async def health_check():
|
| 27 |
"""Health check endpoint."""
|
| 28 |
return {"status": "ok"}
|
|
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from app.models import EmbedRequest, EmbedResponse
|
| 3 |
from app.embeddings import embed_text
|
| 4 |
+
from app.logger import logger
|
| 5 |
|
| 6 |
app = FastAPI(
|
| 7 |
title="Embedding API",
|
| 8 |
description="A simple API to generate text embeddings using Microsoft's `multilingual-e5-large` model.",
|
| 9 |
version="1.0.0",
|
| 10 |
)
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@app.post("/embed", response_model=EmbedResponse)
|
| 14 |
+
async def embed(request: EmbedRequest) -> dict[str, list[list[float]]]:
|
| 15 |
"""Generate embeddings for a list of texts."""
|
| 16 |
+
logger.info("Generating embeddings...")
|
| 17 |
try:
|
| 18 |
vectors = embed_text(request.texts)
|
| 19 |
+
logger.info("Embeddings generated successfully!")
|
| 20 |
return {"embeddings": vectors}
|
| 21 |
except Exception as e:
|
| 22 |
logger.exception("Error generating embeddings")
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
@app.get("/health")
|
| 27 |
+
async def health_check() -> dict[str, str]:
|
| 28 |
"""Health check endpoint."""
|
| 29 |
return {"status": "ok"}
|
pyproject.toml
CHANGED
|
@@ -17,3 +17,27 @@ dependencies = [
|
|
| 17 |
"transformers>=4.57.0",
|
| 18 |
"uvicorn>=0.37.0",
|
| 19 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"transformers>=4.57.0",
|
| 18 |
"uvicorn>=0.37.0",
|
| 19 |
]
|
| 20 |
+
|
| 21 |
+
# https://quantlane.com/blog/type-checking-large-codebase/
|
| 22 |
+
[tool.mypy]
|
| 23 |
+
# Ensure full coverage
|
| 24 |
+
disallow_untyped_calls = false
|
| 25 |
+
disallow_untyped_defs = true
|
| 26 |
+
disallow_incomplete_defs = true
|
| 27 |
+
disallow_untyped_decorators = false
|
| 28 |
+
check_untyped_defs = true
|
| 29 |
+
|
| 30 |
+
# Restrict dynamic typing
|
| 31 |
+
disallow_any_generics = false
|
| 32 |
+
disallow_subclassing_any = false
|
| 33 |
+
warn_return_any = false
|
| 34 |
+
|
| 35 |
+
# Know exactly what you're doing
|
| 36 |
+
warn_redundant_casts = true
|
| 37 |
+
warn_unused_ignores = false
|
| 38 |
+
warn_unused_configs = true
|
| 39 |
+
warn_unreachable = true
|
| 40 |
+
show_error_codes = true
|
| 41 |
+
|
| 42 |
+
# Explicit is better than implicit
|
| 43 |
+
no_implicit_optional = true
|
tests/test_api.py
CHANGED
|
@@ -4,7 +4,7 @@ from app.main import app
|
|
| 4 |
client = TestClient(app)
|
| 5 |
|
| 6 |
|
| 7 |
-
def test_embed():
|
| 8 |
"""Test the /embed endpoint with valid input."""
|
| 9 |
response = client.post("/embed", json={"texts": ["query: Hello world"]})
|
| 10 |
assert response.status_code == 200 # OK
|
|
@@ -13,13 +13,13 @@ def test_embed():
|
|
| 13 |
assert len(data["embeddings"][0]) == 1024
|
| 14 |
|
| 15 |
|
| 16 |
-
def test_embed_no_texts():
|
| 17 |
"""Test the /embed endpoint with no texts provided."""
|
| 18 |
response = client.post("/embed", json={})
|
| 19 |
assert response.status_code == 422 # Unprocessable Entity
|
| 20 |
|
| 21 |
|
| 22 |
-
def test_embed_long_text():
|
| 23 |
"""Test the /embed endpoint with a text longer than 2000 characters."""
|
| 24 |
long_text = "query: " + "a" * 1994 # 2001 characters
|
| 25 |
response = client.post("/embed", json={"texts": [long_text]})
|
|
|
|
| 4 |
client = TestClient(app)
|
| 5 |
|
| 6 |
|
| 7 |
+
def test_embed() -> None:
|
| 8 |
"""Test the /embed endpoint with valid input."""
|
| 9 |
response = client.post("/embed", json={"texts": ["query: Hello world"]})
|
| 10 |
assert response.status_code == 200 # OK
|
|
|
|
| 13 |
assert len(data["embeddings"][0]) == 1024
|
| 14 |
|
| 15 |
|
| 16 |
+
def test_embed_no_texts() -> None:
|
| 17 |
"""Test the /embed endpoint with no texts provided."""
|
| 18 |
response = client.post("/embed", json={})
|
| 19 |
assert response.status_code == 422 # Unprocessable Entity
|
| 20 |
|
| 21 |
|
| 22 |
+
def test_embed_long_text() -> None:
|
| 23 |
"""Test the /embed endpoint with a text longer than 2000 characters."""
|
| 24 |
long_text = "query: " + "a" * 1994 # 2001 characters
|
| 25 |
response = client.post("/embed", json={"texts": [long_text]})
|
tests/test_embeddings.py
CHANGED
|
@@ -3,7 +3,7 @@ import torch
|
|
| 3 |
import pytest
|
| 4 |
|
| 5 |
|
| 6 |
-
def test_average_pool_basic():
|
| 7 |
"""Test average pooling produces correct shape and masking."""
|
| 8 |
last_hidden_states = torch.tensor(
|
| 9 |
[
|
|
@@ -29,7 +29,7 @@ def test_average_pool_basic():
|
|
| 29 |
assert result.shape == (2, 2)
|
| 30 |
|
| 31 |
|
| 32 |
-
def test_embed_text_valid():
|
| 33 |
"""Test embedding returns correct number of vectors and dimensions."""
|
| 34 |
|
| 35 |
texts = ["query: Hello world", "query: Hej verden"]
|
|
@@ -43,13 +43,13 @@ def test_embed_text_valid():
|
|
| 43 |
assert len(embeddings[0]) == 1024
|
| 44 |
|
| 45 |
|
| 46 |
-
def test_embed_text_empty_list():
|
| 47 |
"""Should raise ValueError if no input texts."""
|
| 48 |
with pytest.raises(ValueError, match="No input texts provided"):
|
| 49 |
embed_text([])
|
| 50 |
|
| 51 |
|
| 52 |
-
def test_embed_text_too_long():
|
| 53 |
"""Should raise ValueError for inputs exceeding 2000 characters."""
|
| 54 |
too_long = ["query: " + "a" * 1994] # 2001 characters
|
| 55 |
with pytest.raises(ValueError, match="exceed the maximum length"):
|
|
|
|
| 3 |
import pytest
|
| 4 |
|
| 5 |
|
| 6 |
+
def test_average_pool_basic() -> None:
|
| 7 |
"""Test average pooling produces correct shape and masking."""
|
| 8 |
last_hidden_states = torch.tensor(
|
| 9 |
[
|
|
|
|
| 29 |
assert result.shape == (2, 2)
|
| 30 |
|
| 31 |
|
| 32 |
+
def test_embed_text_valid() -> None:
|
| 33 |
"""Test embedding returns correct number of vectors and dimensions."""
|
| 34 |
|
| 35 |
texts = ["query: Hello world", "query: Hej verden"]
|
|
|
|
| 43 |
assert len(embeddings[0]) == 1024
|
| 44 |
|
| 45 |
|
| 46 |
+
def test_embed_text_empty_list() -> None:
|
| 47 |
"""Should raise ValueError if no input texts."""
|
| 48 |
with pytest.raises(ValueError, match="No input texts provided"):
|
| 49 |
embed_text([])
|
| 50 |
|
| 51 |
|
| 52 |
+
def test_embed_text_too_long() -> None:
|
| 53 |
"""Should raise ValueError for inputs exceeding 2000 characters."""
|
| 54 |
too_long = ["query: " + "a" * 1994] # 2001 characters
|
| 55 |
with pytest.raises(ValueError, match="exceed the maximum length"):
|