emilbm committed on
Commit
c4d1eaa
·
1 Parent(s): 8ff7bad

Added mypy rules and applied them.

Browse files

Created logger and showcased basic return of info.

app/embeddings.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import AutoTokenizer, AutoModel
2
  from torch import Tensor
 
3
 
4
  model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
5
  tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
@@ -28,6 +29,9 @@ def embed_text(texts: list[str]) -> list[list[float]]:
28
  batch_dict = tokenizer(
29
  texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
30
  )
 
 
 
31
  outputs = model(**batch_dict)
32
  embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
33
 
 
1
  from transformers import AutoTokenizer, AutoModel
2
  from torch import Tensor
3
+ from app.logger import logger
4
 
5
  model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
6
  tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
 
29
  batch_dict = tokenizer(
30
  texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
31
  )
32
+ logger.info(
33
+ f"Tokenized {len(texts)} texts with number of tokens per text: {batch_dict['input_ids'].ne(tokenizer.pad_token_id).sum(dim=1).tolist()}"
34
+ )
35
  outputs = model(**batch_dict)
36
  embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
37
 
app/logger.py CHANGED
@@ -1,5 +1,36 @@
1
  import logging
 
2
 
3
- logging.basicConfig(
4
- level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
5
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Application-wide logging setup.

Importing this module configures logging via ``logging.config.dictConfig``
as a side effect and exposes one shared ``logger`` instance named ``"app"``.
"""

import logging
from logging.config import dictConfig

# Timestamp format shared by every formatter below.
_DATEFMT = "%Y-%m-%d %H:%M:%S"

# Human-readable console format plus a JSON-shaped alternative.
# NOTE(review): the "json" formatter is not wired to any handler here, and a
# message containing double quotes would produce invalid JSON — confirm before
# enabling it.
_FORMATTERS = {
    "default": {
        "format": "[%(asctime)s] [%(levelname)s] %(name)s: %(message)s",
        "datefmt": _DATEFMT,
    },
    "json": {
        "format": (
            '{"time": "%(asctime)s", '
            '"level": "%(levelname)s", '
            '"name": "%(name)s", '
            '"message": "%(message)s"}'
        ),
        "datefmt": _DATEFMT,
    },
}

# Single stderr handler using the human-readable format.
_HANDLERS = {
    "console": {
        "class": "logging.StreamHandler",
        "formatter": "default",
    },
}

# dictConfig schema (version 1). Existing loggers stay enabled so records
# from libraries configured before this import are still emitted.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": _FORMATTERS,
    "handlers": _HANDLERS,
    "root": {
        "level": "INFO",
        "handlers": ["console"],
    },
}

# Apply the configuration at import time.
dictConfig(LOGGING_CONFIG)

# Shared logger for the application.
logger = logging.getLogger("app")
app/main.py CHANGED
@@ -1,21 +1,22 @@
1
  from fastapi import FastAPI, HTTPException
2
  from app.models import EmbedRequest, EmbedResponse
3
  from app.embeddings import embed_text
4
- import logging
5
 
6
  app = FastAPI(
7
  title="Embedding API",
8
  description="A simple API to generate text embeddings using Microsoft's `multilingual-e5-large` model.",
9
  version="1.0.0",
10
  )
11
- logger = logging.getLogger(__name__)
12
 
13
 
14
  @app.post("/embed", response_model=EmbedResponse)
15
- async def embed(request: EmbedRequest):
16
  """Generate embeddings for a list of texts."""
 
17
  try:
18
  vectors = embed_text(request.texts)
 
19
  return {"embeddings": vectors}
20
  except Exception as e:
21
  logger.exception("Error generating embeddings")
@@ -23,6 +24,6 @@ async def embed(request: EmbedRequest):
23
 
24
 
25
  @app.get("/health")
26
- async def health_check():
27
  """Health check endpoint."""
28
  return {"status": "ok"}
 
1
  from fastapi import FastAPI, HTTPException
2
  from app.models import EmbedRequest, EmbedResponse
3
  from app.embeddings import embed_text
4
+ from app.logger import logger
5
 
6
  app = FastAPI(
7
  title="Embedding API",
8
  description="A simple API to generate text embeddings using Microsoft's `multilingual-e5-large` model.",
9
  version="1.0.0",
10
  )
 
11
 
12
 
13
  @app.post("/embed", response_model=EmbedResponse)
14
+ async def embed(request: EmbedRequest) -> dict[str, list[list[float]]]:
15
  """Generate embeddings for a list of texts."""
16
+ logger.info("Generating embeddings...")
17
  try:
18
  vectors = embed_text(request.texts)
19
+ logger.info("Embeddings generated successfully!")
20
  return {"embeddings": vectors}
21
  except Exception as e:
22
  logger.exception("Error generating embeddings")
 
24
 
25
 
26
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Report service liveness.

    Returns:
        A constant ``{"status": "ok"}`` payload.
    """
    payload: dict[str, str] = {"status": "ok"}
    return payload
pyproject.toml CHANGED
@@ -17,3 +17,27 @@ dependencies = [
17
  "transformers>=4.57.0",
18
  "uvicorn>=0.37.0",
19
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  "transformers>=4.57.0",
18
  "uvicorn>=0.37.0",
19
  ]
20
+
21
+ # https://quantlane.com/blog/type-checking-large-codebase/
22
+ [tool.mypy]
23
+ # Ensure full coverage
24
+ disallow_untyped_calls = false
25
+ disallow_untyped_defs = true
26
+ disallow_incomplete_defs = true
27
+ disallow_untyped_decorators = false
28
+ check_untyped_defs = true
29
+
30
+ # Restrict dynamic typing
31
+ disallow_any_generics = false
32
+ disallow_subclassing_any = false
33
+ warn_return_any = false
34
+
35
+ # Know exactly what you're doing
36
+ warn_redundant_casts = true
37
+ warn_unused_ignores = false
38
+ warn_unused_configs = true
39
+ warn_unreachable = true
40
+ show_error_codes = true
41
+
42
+ # Explicit is better than implicit
43
+ no_implicit_optional = true
tests/test_api.py CHANGED
@@ -4,7 +4,7 @@ from app.main import app
4
  client = TestClient(app)
5
 
6
 
7
- def test_embed():
8
  """Test the /embed endpoint with valid input."""
9
  response = client.post("/embed", json={"texts": ["query: Hello world"]})
10
  assert response.status_code == 200 # OK
@@ -13,13 +13,13 @@ def test_embed():
13
  assert len(data["embeddings"][0]) == 1024
14
 
15
 
16
- def test_embed_no_texts():
17
  """Test the /embed endpoint with no texts provided."""
18
  response = client.post("/embed", json={})
19
  assert response.status_code == 422 # Unprocessable Entity
20
 
21
 
22
- def test_embed_long_text():
23
  """Test the /embed endpoint with a text longer than 2000 characters."""
24
  long_text = "query: " + "a" * 1994 # 2001 characters
25
  response = client.post("/embed", json={"texts": [long_text]})
 
4
  client = TestClient(app)
5
 
6
 
7
+ def test_embed() -> None:
8
  """Test the /embed endpoint with valid input."""
9
  response = client.post("/embed", json={"texts": ["query: Hello world"]})
10
  assert response.status_code == 200 # OK
 
13
  assert len(data["embeddings"][0]) == 1024
14
 
15
 
16
+ def test_embed_no_texts() -> None:
17
  """Test the /embed endpoint with no texts provided."""
18
  response = client.post("/embed", json={})
19
  assert response.status_code == 422 # Unprocessable Entity
20
 
21
 
22
+ def test_embed_long_text() -> None:
23
  """Test the /embed endpoint with a text longer than 2000 characters."""
24
  long_text = "query: " + "a" * 1994 # 2001 characters
25
  response = client.post("/embed", json={"texts": [long_text]})
tests/test_embeddings.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import pytest
4
 
5
 
6
- def test_average_pool_basic():
7
  """Test average pooling produces correct shape and masking."""
8
  last_hidden_states = torch.tensor(
9
  [
@@ -29,7 +29,7 @@ def test_average_pool_basic():
29
  assert result.shape == (2, 2)
30
 
31
 
32
- def test_embed_text_valid():
33
  """Test embedding returns correct number of vectors and dimensions."""
34
 
35
  texts = ["query: Hello world", "query: Hej verden"]
@@ -43,13 +43,13 @@ def test_embed_text_valid():
43
  assert len(embeddings[0]) == 1024
44
 
45
 
46
- def test_embed_text_empty_list():
47
  """Should raise ValueError if no input texts."""
48
  with pytest.raises(ValueError, match="No input texts provided"):
49
  embed_text([])
50
 
51
 
52
- def test_embed_text_too_long():
53
  """Should raise ValueError for inputs exceeding 2000 characters."""
54
  too_long = ["query: " + "a" * 1994] # 2001 characters
55
  with pytest.raises(ValueError, match="exceed the maximum length"):
 
3
  import pytest
4
 
5
 
6
+ def test_average_pool_basic() -> None:
7
  """Test average pooling produces correct shape and masking."""
8
  last_hidden_states = torch.tensor(
9
  [
 
29
  assert result.shape == (2, 2)
30
 
31
 
32
+ def test_embed_text_valid() -> None:
33
  """Test embedding returns correct number of vectors and dimensions."""
34
 
35
  texts = ["query: Hello world", "query: Hej verden"]
 
43
  assert len(embeddings[0]) == 1024
44
 
45
 
46
+ def test_embed_text_empty_list() -> None:
47
  """Should raise ValueError if no input texts."""
48
  with pytest.raises(ValueError, match="No input texts provided"):
49
  embed_text([])
50
 
51
 
52
+ def test_embed_text_too_long() -> None:
53
  """Should raise ValueError for inputs exceeding 2000 characters."""
54
  too_long = ["query: " + "a" * 1994] # 2001 characters
55
  with pytest.raises(ValueError, match="exceed the maximum length"):