jeanbaptdzd committed on
Commit
c77ec91
·
1 Parent(s): 58ff73c

Remove chat_service.py abstraction layer

Browse files

- Remove unnecessary chat_service.py pass-through layer
- Update router to call transformers_provider directly
- Update tests to mock provider functions instead of service layer
- Simplify architecture: Router → Provider → Model

app/routers/openai_api.py CHANGED
@@ -6,8 +6,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
6
 
7
  from app.config import settings
8
  from app.models.openai import ChatCompletionRequest
9
- from app.services import chat_service
10
- from app.providers.transformers_provider import initialize_model
11
 
12
  logger = logging.getLogger(__name__)
13
 
@@ -15,9 +14,9 @@ router = APIRouter()
15
 
16
 
17
  @router.get("/models")
18
- async def list_models():
19
  """List available models (OpenAI-compatible endpoint)"""
20
- return await chat_service.list_models()
21
 
22
 
23
  @router.post("/models/reload")
@@ -115,12 +114,12 @@ async def chat_completions(body: ChatCompletionRequest):
115
  logger.info(f"Chat completion request: model={payload['model']}, messages={len(payload['messages'])}, stream={payload['stream']}")
116
 
117
  if body.stream:
118
- stream = await chat_service.chat(payload, stream=True)
119
  # stream is already an AsyncIterator[str] with SSE-formatted chunks
120
  return StreamingResponse(stream, media_type="text/event-stream")
121
 
122
  # Non-streaming response
123
- data = await chat_service.chat(payload, stream=False)
124
  return JSONResponse(content=data)
125
 
126
  except ValueError as e:
 
6
 
7
  from app.config import settings
8
  from app.models.openai import ChatCompletionRequest
9
+ from app.providers.transformers_provider import initialize_model, chat, list_models
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
14
 
15
 
16
  @router.get("/models")
17
+ async def list_models_endpoint():
18
  """List available models (OpenAI-compatible endpoint)"""
19
+ return await list_models()
20
 
21
 
22
  @router.post("/models/reload")
 
114
  logger.info(f"Chat completion request: model={payload['model']}, messages={len(payload['messages'])}, stream={payload['stream']}")
115
 
116
  if body.stream:
117
+ stream = await chat(payload, stream=True)
118
  # stream is already an AsyncIterator[str] with SSE-formatted chunks
119
  return StreamingResponse(stream, media_type="text/event-stream")
120
 
121
  # Non-streaming response
122
+ data = await chat(payload, stream=False)
123
  return JSONResponse(content=data)
124
 
125
  except ValueError as e:
app/services/chat_service.py DELETED
@@ -1,33 +0,0 @@
1
- """Chat service layer providing abstraction over the provider."""
2
- from typing import Any, Dict, Union, AsyncIterator
3
-
4
- from app.providers import transformers_provider as provider
5
-
6
-
7
- async def list_models() -> Dict[str, Any]:
8
- """
9
- List available models.
10
-
11
- Returns:
12
- Dictionary containing model list in OpenAI-compatible format
13
- """
14
- return await provider.list_models()
15
-
16
-
17
- async def chat(
18
- payload: Dict[str, Any],
19
- stream: bool = False
20
- ) -> Union[Dict[str, Any], AsyncIterator[str]]:
21
- """
22
- Process chat completion request.
23
-
24
- Args:
25
- payload: Request payload containing messages and generation parameters
26
- stream: Whether to stream the response
27
-
28
- Returns:
29
- Response dictionary or async iterator for streaming
30
- """
31
- return await provider.chat(payload, stream=stream)
32
-
33
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_openai_routes.py CHANGED
@@ -10,9 +10,9 @@ def test_models(monkeypatch):
10
  async def fake_list_models():
11
  return {"data": [{"id": "DragonLLM/LLM-Pro-Finance-Small"}]}
12
 
13
- from app.services import chat_service
14
 
15
- monkeypatch.setattr(chat_service, "list_models", fake_list_models)
16
 
17
  r = client.get("/v1/models")
18
  assert r.status_code == 200
@@ -37,9 +37,9 @@ def test_chat_completions(monkeypatch):
37
  ],
38
  }
39
 
40
- from app.services import chat_service
41
 
42
- monkeypatch.setattr(chat_service, "chat", fake_chat)
43
 
44
  r = client.post(
45
  "/v1/chat/completions",
 
10
  async def fake_list_models():
11
  return {"data": [{"id": "DragonLLM/LLM-Pro-Finance-Small"}]}
12
 
13
+ from app.providers import transformers_provider
14
 
15
+ monkeypatch.setattr(transformers_provider, "list_models", fake_list_models)
16
 
17
  r = client.get("/v1/models")
18
  assert r.status_code == 200
 
37
  ],
38
  }
39
 
40
+ from app.providers import transformers_provider
41
 
42
+ monkeypatch.setattr(transformers_provider, "chat", fake_chat)
43
 
44
  r = client.post(
45
  "/v1/chat/completions",