jeanbaptdzd committed
Commit 192844a · 2 Parent(s): d4fd4e1 6a4421a

Merge feat/tool-enabling into master - resolve conflicts

- Keep tool calls implementation from feat/tool-enabling
- Keep latest changes from origin/master for main.py and middleware.py

.coderabbit.yaml CHANGED
@@ -16,7 +16,6 @@ review:
   simple: false # Set to true for faster, simpler reviews
   high_level_summary: true
   estimate_time: true
-  project_language: python
 
   chat:
     enabled: true
app/config.py CHANGED
@@ -1,11 +1,33 @@
+"""Application configuration using Pydantic settings."""
+
+from typing import Literal
+from pydantic import Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
 class Settings(BaseSettings):
-    model: str = "DragonLLM/qwen3-8b-fin-v1.0"
-    service_api_key: str | None = None
-    log_level: str = "info"
-    force_model_reload: bool = False  # Set FORCE_MODEL_RELOAD=true to bypass cache on startup
+    """Application settings loaded from environment variables.
+
+    Supports loading from .env file with UTF-8 encoding.
+    All settings can be overridden via environment variables.
+    """
+
+    model: str = Field(
+        default="DragonLLM/qwen3-8b-fin-v1.0",
+        description="Hugging Face model identifier"
+    )
+    service_api_key: str | None = Field(
+        default=None,
+        description="Optional API key for authentication (SERVICE_API_KEY env var)"
+    )
+    log_level: Literal["debug", "info", "warning", "error"] = Field(
+        default="info",
+        description="Logging level"
+    )
+    force_model_reload: bool = Field(
+        default=False,
+        description="Force model reload from Hugging Face, bypassing cache (FORCE_MODEL_RELOAD env var)"
+    )
 
     model_config = SettingsConfigDict(
        env_file=".env",
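
For illustration, a minimal sketch of how the new Field-based settings resolve from the environment. This relies on pydantic-settings' default mapping of each field name to an upper-cased env var; the override values below are made up:

import os

from app.config import Settings

# Pydantic maps each field to an env var by name: MODEL, SERVICE_API_KEY,
# LOG_LEVEL, FORCE_MODEL_RELOAD (case-insensitive).
os.environ["LOG_LEVEL"] = "debug"
os.environ["FORCE_MODEL_RELOAD"] = "true"

settings = Settings()
assert settings.log_level == "debug"        # validated against the Literal choices
assert settings.force_model_reload is True  # "true" coerced to bool
print(settings.model)                       # default: DragonLLM/qwen3-8b-fin-v1.0
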
app/main.py CHANGED
@@ -1,30 +1,39 @@
-from typing import Dict, Any
+"""Main FastAPI application entry point."""
+
+import logging
+import threading
+from typing import Dict
+
 from fastapi import FastAPI
+
+from app.config import settings
 from app.middleware import api_key_guard
-from app.middleware.rate_limit import rate_limit_middleware
 from app.routers import openai_api
-from app.config import settings
-from app.providers.transformers_provider import model, _initialized
-from app.utils.stats import get_stats_tracker
-import logging
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="LLM Pro Finance API (Transformers)")
+app = FastAPI(
+    title="LLM Pro Finance API (Transformers)",
+    description="OpenAI-compatible API for financial LLM inference",
+    version="1.0.0"
+)
 
 # Mount routers
 app.include_router(openai_api.router, prefix="/v1")
 
-# Middleware order: rate limiting first, then API key guard
-app.middleware("http")(rate_limit_middleware)
+# Optional API key middleware
 app.middleware("http")(api_key_guard)
 
+
 @app.on_event("startup")
-async def startup_event():
-    """Startup event - initialize model in background"""
-    import threading
+async def startup_event() -> None:
+    """Startup event - initialize model in background thread.
+
+    Loads the model asynchronously to avoid blocking the API startup.
+    Model loading happens in a daemon thread so it doesn't prevent shutdown.
+    """
     logger.info("Starting LLM Pro Finance API...")
 
     force_reload = settings.force_model_reload
@@ -33,7 +42,8 @@ async def startup_event():
 
     logger.info("Initializing model in background thread...")
 
-    def load_model():
+    def load_model() -> None:
+        """Load the model in a background thread."""
         from app.providers.transformers_provider import initialize_model
         initialize_model(force_reload=force_reload)
 
@@ -42,32 +52,30 @@ async def startup_event():
     thread.start()
     logger.info("Model initialization started in background")
 
+
 @app.get("/")
 async def root() -> Dict[str, str]:
-    """Root endpoint returning API status and information."""
+    """Root endpoint returning API status and information.
+
+    Returns:
+        Dictionary containing API status, service name, version, model, and backend.
+    """
     return {
         "status": "ok",
         "service": "Qwen Open Finance R 8B Inference",
         "version": "1.0.0",
-        "model": "DragonLLM/qwen3-8b-fin-v1.0",
+        "model": settings.model,
         "backend": "Transformers"
     }
 
-@app.get("/health")
-async def health() -> Dict[str, Any]:
-    """Health check endpoint with model readiness status."""
-    model_ready = _initialized and model is not None
-    return {
-        "status": "healthy" if model_ready else "initializing",
-        "service": "LLM Pro Finance API",
-        "model_ready": model_ready,
-    }
-
-
-@app.get("/v1/stats")
-async def get_stats() -> Dict[str, Any]:
-    """Get API usage statistics and token counts."""
-    stats_tracker = get_stats_tracker()
-    return stats_tracker.get_stats()
+
+@app.get("/health")
+async def health() -> Dict[str, str]:
+    """Health check endpoint for monitoring and load balancers.
+
+    Returns:
+        Dictionary with service health status.
+    """
+    return {"status": "healthy", "service": "LLM Pro Finance API"}
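
A quick smoke test of the merged endpoints, sketched with httpx against a locally running instance (the port and client library are assumptions):

import httpx

BASE = "http://localhost:8000"  # assumed local dev address

# "/" now reports the configured model via settings.model instead of a hard-coded string
root = httpx.get(f"{BASE}/").json()
print(root["model"], root["backend"])

# "/health" returns a static liveness payload after this merge; it no longer
# reflects whether the background thread has finished loading the model
print(httpx.get(f"{BASE}/health").json())  # {"status": "healthy", ...}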
 
app/middleware.py CHANGED
@@ -1,26 +1,46 @@
-from fastapi import Request, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi import Request
+from fastapi.responses import JSONResponse, Response
+from typing import Callable, Awaitable, Union
 
 from app.config import settings
 
+# Public endpoints that don't require authentication
+PUBLIC_PATHS = frozenset(["/", "/health", "/docs", "/redoc", "/openapi.json"])
 
-async def api_key_guard(request: Request, call_next):
-    # Public endpoints that don't require authentication
-    public_paths = ["/", "/health", "/docs", "/redoc", "/openapi.json", "/v1/stats"]
+
+async def api_key_guard(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Union[Response, JSONResponse]:
+    """
+    Middleware to protect API endpoints with optional API key authentication.
+
+    Args:
+        request: FastAPI request object
+        call_next: Next middleware/handler in the chain
+
+    Returns:
+        Response from next handler or 401 if unauthorized
+    """
     # Skip auth for public endpoints
-    if request.url.path in public_paths:
+    if request.url.path in PUBLIC_PATHS:
         return await call_next(request)
 
     # Skip auth if no API key is configured
     if not settings.service_api_key:
         return await call_next(request)
 
-    # Check API key
-    key = request.headers.get("x-api-key") or request.headers.get("authorization")
-    if key and key.replace("Bearer ", "").strip() == settings.service_api_key:
+    # Check API key from headers
+    api_key = request.headers.get("x-api-key")
+    if not api_key:
+        # Also check Authorization header with Bearer token
+        auth_header = request.headers.get("authorization", "")
+        if auth_header.startswith("Bearer "):
+            api_key = auth_header.replace("Bearer ", "").strip()
+
+    if api_key and api_key == settings.service_api_key:
         return await call_next(request)
 
-    return JSONResponse({"error": "unauthorized"}, status_code=401)
+    return JSONResponse(
+        content={"error": {"message": "unauthorized", "type": "authentication_error"}},
+        status_code=401
+    )
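
A sketch of how the rewritten guard behaves, using Starlette's TestClient. The /v1/models route and the placeholder key are assumptions, and TestClient only runs the startup hook inside a with block, so no model is loaded here:

import os

os.environ["SERVICE_API_KEY"] = "secret"  # placeholder; set before app.config is imported

from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)

# Public paths bypass the guard entirely
assert client.get("/health").status_code == 200

# Either header form carries the key
assert client.get("/v1/models", headers={"x-api-key": "secret"}).status_code != 401
assert client.get("/v1/models", headers={"Authorization": "Bearer secret"}).status_code != 401

# A missing or wrong key now gets the OpenAI-style error body
resp = client.get("/v1/models", headers={"x-api-key": "wrong"})
print(resp.status_code, resp.json())  # 401 {'error': {'message': 'unauthorized', ...}}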
 
app/providers/base.py CHANGED
@@ -1,11 +1,33 @@
-from typing import Protocol, Dict, Any
+"""Base protocol for LLM providers."""
+
+from typing import Any, Dict, Protocol
 
 
 class LLMProvider(Protocol):
+    """Protocol defining the interface for LLM providers.
+
+    Any class implementing this protocol must provide async methods
+    for listing models and generating chat completions.
+    """
+
     async def list_models(self) -> Dict[str, Any]:
+        """List available models.
+
+        Returns:
+            Dictionary containing model information.
+        """
         ...
 
     async def chat(self, payload: Dict[str, Any], stream: bool = False) -> Any:
+        """Generate chat completion.
+
+        Args:
+            payload: Request payload containing messages and parameters
+            stream: Whether to stream the response
+
+        Returns:
+            Chat completion response (varies by implementation)
+        """
         ...
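
Because LLMProvider is a typing.Protocol, conformance is structural rather than nominal. A hypothetical stub like the following type-checks without subclassing it:

from typing import Any, Dict

from app.providers.base import LLMProvider


class EchoProvider:
    """Hypothetical provider used only to illustrate structural typing."""

    async def list_models(self) -> Dict[str, Any]:
        return {"object": "list", "data": [{"id": "echo", "object": "model"}]}

    async def chat(self, payload: Dict[str, Any], stream: bool = False) -> Any:
        # Echo the last user message back in OpenAI response shape
        last = payload["messages"][-1]["content"]
        return {"choices": [{"message": {"role": "assistant", "content": last}}]}


provider: LLMProvider = EchoProvider()  # accepted: method signatures match the Protocol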
 
app/services/chat_service.py CHANGED
@@ -1,13 +1,33 @@
-from typing import Any, Dict
+"""Chat service layer providing abstraction over the provider."""
+from typing import Any, Dict, Union, AsyncIterator
 
 from app.providers import transformers_provider as provider
 
 
 async def list_models() -> Dict[str, Any]:
+    """
+    List available models.
+
+    Returns:
+        Dictionary containing model list in OpenAI-compatible format
+    """
     return await provider.list_models()
 
 
-async def chat(payload: Dict[str, Any], stream: bool = False):
+async def chat(
+    payload: Dict[str, Any],
+    stream: bool = False
+) -> Union[Dict[str, Any], AsyncIterator[str]]:
+    """
+    Process chat completion request.
+
+    Args:
+        payload: Request payload containing messages and generation parameters
+        stream: Whether to stream the response
+
+    Returns:
+        Response dictionary or async iterator for streaming
+    """
     return await provider.chat(payload, stream=stream)
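
Usage sketch for the annotated service layer. The payload mirrors the OpenAI chat format; actually running it requires the transformers provider to have loaded the model:

import asyncio

from app.services import chat_service


async def main() -> None:
    payload = {
        "model": "DragonLLM/qwen3-8b-fin-v1.0",
        "messages": [{"role": "user", "content": "Define duration risk."}],
    }
    # stream=False returns a dict; stream=True returns an AsyncIterator[str],
    # per the new return annotation
    result = await chat_service.chat(payload, stream=False)
    print(result)


asyncio.run(main())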
 
app/utils/constants.py CHANGED
@@ -1,18 +1,25 @@
-"""Application-wide constants."""
+"""Application-wide constants and configuration."""
 
 import os
+from typing import Final, List
+
 
 # Model configuration
-MODEL_NAME = "DragonLLM/qwen3-8b-fin-v1.0"
+MODEL_NAME: Final[str] = "DragonLLM/qwen3-8b-fin-v1.0"
 
 # Cache directory - respect HF_HOME if set, otherwise use default
-CACHE_DIR = os.getenv("HF_HOME", "/tmp/huggingface")
+CACHE_DIR: Final[str] = os.getenv("HF_HOME", "/tmp/huggingface")
 
 # Hugging Face token environment variable priority order
-HF_TOKEN_VARS = ["HF_TOKEN_LC2", "HF_TOKEN_LC", "HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"]
+HF_TOKEN_VARS: Final[List[str]] = [
+    "HF_TOKEN_LC2",
+    "HF_TOKEN_LC",
+    "HF_TOKEN",
+    "HUGGING_FACE_HUB_TOKEN"
+]
 
 # French language detection patterns
-FRENCH_PHRASES = [
+FRENCH_PHRASES: Final[List[str]] = [
     "en français",
     "répondez en français",
     "réponse française",
@@ -20,9 +27,11 @@ FRENCH_PHRASES = [
     "expliquez en français",
 ]
 
-FRENCH_CHARS = ["é", "è", "ê", "à", "ç", "ù", "ô", "î", "â", "û", "ë", "ï"]
+FRENCH_CHARS: Final[List[str]] = [
+    "é", "è", "ê", "à", "ç", "ù", "ô", "î", "â", "û", "ë", "ï"
+]
 
-FRENCH_PATTERNS = [
+FRENCH_PATTERNS: Final[List[str]] = [
     "qu'est-ce",
     "qu'est",
     "expliquez",
@@ -38,7 +47,7 @@ FRENCH_PATTERNS = [
     "définissez",
 ]
 
-FRENCH_SYSTEM_PROMPT = (
+FRENCH_SYSTEM_PROMPT: Final[str] = (
     "Vous êtes un assistant financier expert. "
     "Répondez TOUJOURS en français. "
     "Soyez concis et précis dans vos explications. "
@@ -46,15 +55,15 @@ FRENCH_SYSTEM_PROMPT = (
 )
 
 # Qwen3 EOS tokens
-EOS_TOKENS = [151645, 151643]  # [<|im_end|>, <|endoftext|>]
-PAD_TOKEN_ID = 151643  # <|endoftext|>
+EOS_TOKENS: Final[List[int]] = [151645, 151643]  # [<|im_end|>, <|endoftext|>]
+PAD_TOKEN_ID: Final[int] = 151643  # <|endoftext|>
 
 # Generation defaults
-DEFAULT_MAX_TOKENS = 1000  # Increased for complete answers with concise reasoning
-DEFAULT_TEMPERATURE = 0.7
-DEFAULT_TOP_P = 1.0
-DEFAULT_TOP_K = 20
-REPETITION_PENALTY = 1.05
+DEFAULT_MAX_TOKENS: Final[int] = 1000  # Increased for complete answers with concise reasoning
+DEFAULT_TEMPERATURE: Final[float] = 0.7
+DEFAULT_TOP_P: Final[float] = 1.0
+DEFAULT_TOP_K: Final[int] = 20
+REPETITION_PENALTY: Final[float] = 1.05
 
 # Model initialization constants
 MODEL_INIT_TIMEOUT_SECONDS = 300  # 5 minutes timeout for model initialization
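
What the Final annotations buy: a type checker now rejects rebinding these names, though it does not freeze container contents. A minimal illustration (mypy message paraphrased):

from app.utils.constants import DEFAULT_TEMPERATURE, EOS_TOKENS

DEFAULT_TEMPERATURE = 0.9  # mypy: Cannot assign to final name "DEFAULT_TEMPERATURE"

# Final does not deep-freeze containers, so this still passes the type checker;
# the lists remain mutable and are read-only by convention only
EOS_TOKENS.append(0)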
app/utils/helpers.py CHANGED
@@ -2,7 +2,7 @@
 
 import os
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Dict, Any
 
 from app.utils.constants import HF_TOKEN_VARS, FRENCH_PHRASES, FRENCH_CHARS, FRENCH_PATTERNS
 
@@ -24,7 +24,7 @@ def get_hf_token() -> Tuple[Optional[str], str]:
     return None, "none"
 
 
-def is_french_request(messages: list) -> bool:
+def is_french_request(messages: List[Dict[str, Any]]) -> bool:
     """
     Detect if the request is in French based on user messages.
 
@@ -55,7 +55,7 @@ def is_french_request(messages: list) -> bool:
     return False
 
 
-def has_french_system_prompt(messages: list) -> bool:
+def has_french_system_prompt(messages: List[Dict[str, Any]]) -> bool:
     """Check if messages already contain a French system prompt."""
     return any(
         "français" in msg.get("content", "").lower()
app/utils/memory.py CHANGED
@@ -1,12 +1,23 @@
 """GPU memory management utilities."""
 
 import gc
+from typing import Optional, Any
+
 import torch
-from typing import Optional
 
 
-def clear_gpu_memory(model=None, tokenizer=None):
-    """Clear GPU memory completely."""
+def clear_gpu_memory(model: Optional[Any] = None, tokenizer: Optional[Any] = None) -> None:
+    """Clear GPU memory completely.
+
+    This function performs aggressive GPU memory cleanup by:
+    1. Deleting model and tokenizer objects if provided
+    2. Clearing CUDA cache
+    3. Running multiple garbage collection passes
+
+    Args:
+        model: Optional model object to delete
+        tokenizer: Optional tokenizer object to delete
+    """
     if not torch.cuda.is_available():
         return
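
Typical usage of the annotated helper, as a sketch: the caller should drop its own references first, since the function can only delete the objects it is handed, and the early return makes it safe on CPU-only machines:

import torch

from app.utils.memory import clear_gpu_memory

# No-op without CUDA, thanks to the early return
clear_gpu_memory()

if torch.cuda.is_available():
    tensor = torch.empty(1024, 1024, device="cuda")
    del tensor               # drop our own reference first; the helper can only
    clear_gpu_memory()       # delete what it is handed, then gc + empty_cache
    print(torch.cuda.memory_reserved())  # cached blocks released back to the driver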