Merge feat/tool-enabling into master - resolve conflicts
Browse files
- Keep tool calls implementation from feat/tool-enabling
- Keep latest changes from origin/master for main.py and middleware.py
- .coderabbit.yaml +0 -1
- app/config.py +26 -4
- app/main.py +38 -30
- app/middleware.py +30 -10
- app/providers/base.py +24 -2
- app/services/chat_service.py +22 -2
- app/utils/constants.py +24 -15
- app/utils/helpers.py +3 -3
- app/utils/memory.py +14 -3
.coderabbit.yaml
CHANGED
|
@@ -16,7 +16,6 @@ review:
|
|
| 16 |
simple: false # Set to true for faster, simpler reviews
|
| 17 |
high_level_summary: true
|
| 18 |
estimate_time: true
|
| 19 |
-
project_language: python
|
| 20 |
|
| 21 |
chat:
|
| 22 |
enabled: true
|
|
|
|
| 16 |
simple: false # Set to true for faster, simpler reviews
|
| 17 |
high_level_summary: true
|
| 18 |
estimate_time: true
|
|
|
|
| 19 |
|
| 20 |
chat:
|
| 21 |
enabled: true
|
app/config.py
CHANGED
|
@@ -1,11 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 2 |
|
| 3 |
|
| 4 |
class Settings(BaseSettings):
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
model_config = SettingsConfigDict(
|
| 11 |
env_file=".env",
|
|
|
|
| 1 |
+
"""Application configuration using Pydantic settings."""
|
| 2 |
+
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from pydantic import Field
|
| 5 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 6 |
|
| 7 |
|
| 8 |
class Settings(BaseSettings):
|
| 9 |
+
"""Application settings loaded from environment variables.
|
| 10 |
+
|
| 11 |
+
Supports loading from .env file with UTF-8 encoding.
|
| 12 |
+
All settings can be overridden via environment variables.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
model: str = Field(
|
| 16 |
+
default="DragonLLM/qwen3-8b-fin-v1.0",
|
| 17 |
+
description="Hugging Face model identifier"
|
| 18 |
+
)
|
| 19 |
+
service_api_key: str | None = Field(
|
| 20 |
+
default=None,
|
| 21 |
+
description="Optional API key for authentication (SERVICE_API_KEY env var)"
|
| 22 |
+
)
|
| 23 |
+
log_level: Literal["debug", "info", "warning", "error"] = Field(
|
| 24 |
+
default="info",
|
| 25 |
+
description="Logging level"
|
| 26 |
+
)
|
| 27 |
+
force_model_reload: bool = Field(
|
| 28 |
+
default=False,
|
| 29 |
+
description="Force model reload from Hugging Face, bypassing cache (FORCE_MODEL_RELOAD env var)"
|
| 30 |
+
)
|
| 31 |
|
| 32 |
model_config = SettingsConfigDict(
|
| 33 |
env_file=".env",
|
app/main.py
CHANGED
|
@@ -1,30 +1,39 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from fastapi import FastAPI
|
|
|
|
|
|
|
| 3 |
from app.middleware import api_key_guard
|
| 4 |
-
from app.middleware.rate_limit import rate_limit_middleware
|
| 5 |
from app.routers import openai_api
|
| 6 |
-
from app.config import settings
|
| 7 |
-
from app.providers.transformers_provider import model, _initialized
|
| 8 |
-
from app.utils.stats import get_stats_tracker
|
| 9 |
-
import logging
|
| 10 |
|
| 11 |
# Configure logging
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Mount routers
|
| 18 |
app.include_router(openai_api.router, prefix="/v1")
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
app.middleware("http")(rate_limit_middleware)
|
| 22 |
app.middleware("http")(api_key_guard)
|
| 23 |
|
|
|
|
| 24 |
@app.on_event("startup")
|
| 25 |
-
async def startup_event():
|
| 26 |
-
"""Startup event - initialize model in background
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
logger.info("Starting LLM Pro Finance API...")
|
| 29 |
|
| 30 |
force_reload = settings.force_model_reload
|
|
@@ -33,7 +42,8 @@ async def startup_event():
|
|
| 33 |
|
| 34 |
logger.info("Initializing model in background thread...")
|
| 35 |
|
| 36 |
-
def load_model():
|
|
|
|
| 37 |
from app.providers.transformers_provider import initialize_model
|
| 38 |
initialize_model(force_reload=force_reload)
|
| 39 |
|
|
@@ -42,32 +52,30 @@ async def startup_event():
|
|
| 42 |
thread.start()
|
| 43 |
logger.info("Model initialization started in background")
|
| 44 |
|
|
|
|
| 45 |
@app.get("/")
|
| 46 |
async def root() -> Dict[str, str]:
|
| 47 |
-
"""Root endpoint returning API status and information.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
return {
|
| 49 |
"status": "ok",
|
| 50 |
"service": "Qwen Open Finance R 8B Inference",
|
| 51 |
"version": "1.0.0",
|
| 52 |
-
"model":
|
| 53 |
"backend": "Transformers"
|
| 54 |
}
|
| 55 |
|
| 56 |
-
@app.get("/health")
|
| 57 |
-
async def health() -> Dict[str, Any]:
|
| 58 |
-
"""Health check endpoint with model readiness status."""
|
| 59 |
-
model_ready = _initialized and model is not None
|
| 60 |
-
return {
|
| 61 |
-
"status": "healthy" if model_ready else "initializing",
|
| 62 |
-
"service": "LLM Pro Finance API",
|
| 63 |
-
"model_ready": model_ready,
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
|
| 67 |
-
@app.get("/
|
| 68 |
-
async def
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
|
|
|
| 1 |
+
"""Main FastAPI application entry point."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import threading
|
| 5 |
+
from typing import Dict
|
| 6 |
+
|
| 7 |
from fastapi import FastAPI
|
| 8 |
+
|
| 9 |
+
from app.config import settings
|
| 10 |
from app.middleware import api_key_guard
|
|
|
|
| 11 |
from app.routers import openai_api
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Configure logging
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
+
app = FastAPI(
|
| 18 |
+
title="LLM Pro Finance API (Transformers)",
|
| 19 |
+
description="OpenAI-compatible API for financial LLM inference",
|
| 20 |
+
version="1.0.0"
|
| 21 |
+
)
|
| 22 |
|
| 23 |
# Mount routers
|
| 24 |
app.include_router(openai_api.router, prefix="/v1")
|
| 25 |
|
| 26 |
+
# Optional API key middleware
|
|
|
|
| 27 |
app.middleware("http")(api_key_guard)
|
| 28 |
|
| 29 |
+
|
| 30 |
@app.on_event("startup")
|
| 31 |
+
async def startup_event() -> None:
|
| 32 |
+
"""Startup event - initialize model in background thread.
|
| 33 |
+
|
| 34 |
+
Loads the model asynchronously to avoid blocking the API startup.
|
| 35 |
+
Model loading happens in a daemon thread so it doesn't prevent shutdown.
|
| 36 |
+
"""
|
| 37 |
logger.info("Starting LLM Pro Finance API...")
|
| 38 |
|
| 39 |
force_reload = settings.force_model_reload
|
|
|
|
| 42 |
|
| 43 |
logger.info("Initializing model in background thread...")
|
| 44 |
|
| 45 |
+
def load_model() -> None:
|
| 46 |
+
"""Load the model in a background thread."""
|
| 47 |
from app.providers.transformers_provider import initialize_model
|
| 48 |
initialize_model(force_reload=force_reload)
|
| 49 |
|
|
|
|
| 52 |
thread.start()
|
| 53 |
logger.info("Model initialization started in background")
|
| 54 |
|
| 55 |
+
|
| 56 |
@app.get("/")
|
| 57 |
async def root() -> Dict[str, str]:
|
| 58 |
+
"""Root endpoint returning API status and information.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dictionary containing API status, service name, version, model, and backend.
|
| 62 |
+
"""
|
| 63 |
return {
|
| 64 |
"status": "ok",
|
| 65 |
"service": "Qwen Open Finance R 8B Inference",
|
| 66 |
"version": "1.0.0",
|
| 67 |
+
"model": settings.model,
|
| 68 |
"backend": "Transformers"
|
| 69 |
}
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
@app.get("/health")
|
| 73 |
+
async def health() -> Dict[str, str]:
|
| 74 |
+
"""Health check endpoint for monitoring and load balancers.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Dictionary with service health status.
|
| 78 |
+
"""
|
| 79 |
+
return {"status": "healthy", "service": "LLM Pro Finance API"}
|
| 80 |
|
| 81 |
|
app/middleware.py
CHANGED
|
@@ -1,26 +1,46 @@
|
|
| 1 |
-
from fastapi import Request
|
| 2 |
-
from fastapi.responses import JSONResponse
|
|
|
|
| 3 |
|
| 4 |
from app.config import settings
|
| 5 |
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
# Skip auth for public endpoints
|
| 12 |
-
if request.url.path in
|
| 13 |
return await call_next(request)
|
| 14 |
|
| 15 |
# Skip auth if no API key is configured
|
| 16 |
if not settings.service_api_key:
|
| 17 |
return await call_next(request)
|
| 18 |
|
| 19 |
-
# Check API key
|
| 20 |
-
|
| 21 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
return await call_next(request)
|
| 23 |
|
| 24 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
|
|
|
| 1 |
+
from fastapi import Request
|
| 2 |
+
from fastapi.responses import JSONResponse, Response
|
| 3 |
+
from typing import Callable, Awaitable, Union
|
| 4 |
|
| 5 |
from app.config import settings
|
| 6 |
|
| 7 |
+
# Public endpoints that don't require authentication
|
| 8 |
+
PUBLIC_PATHS = frozenset(["/", "/health", "/docs", "/redoc", "/openapi.json"])
|
| 9 |
|
| 10 |
+
|
| 11 |
+
async def api_key_guard(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Union[Response, JSONResponse]:
|
| 12 |
+
"""
|
| 13 |
+
Middleware to protect API endpoints with optional API key authentication.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
request: FastAPI request object
|
| 17 |
+
call_next: Next middleware/handler in the chain
|
| 18 |
|
| 19 |
+
Returns:
|
| 20 |
+
Response from next handler or 401 if unauthorized
|
| 21 |
+
"""
|
| 22 |
# Skip auth for public endpoints
|
| 23 |
+
if request.url.path in PUBLIC_PATHS:
|
| 24 |
return await call_next(request)
|
| 25 |
|
| 26 |
# Skip auth if no API key is configured
|
| 27 |
if not settings.service_api_key:
|
| 28 |
return await call_next(request)
|
| 29 |
|
| 30 |
+
# Check API key from headers
|
| 31 |
+
api_key = request.headers.get("x-api-key")
|
| 32 |
+
if not api_key:
|
| 33 |
+
# Also check Authorization header with Bearer token
|
| 34 |
+
auth_header = request.headers.get("authorization", "")
|
| 35 |
+
if auth_header.startswith("Bearer "):
|
| 36 |
+
api_key = auth_header.replace("Bearer ", "").strip()
|
| 37 |
+
|
| 38 |
+
if api_key and api_key == settings.service_api_key:
|
| 39 |
return await call_next(request)
|
| 40 |
|
| 41 |
+
return JSONResponse(
|
| 42 |
+
content={"error": {"message": "unauthorized", "type": "authentication_error"}},
|
| 43 |
+
status_code=401
|
| 44 |
+
)
|
| 45 |
|
| 46 |
|
app/providers/base.py
CHANGED
|
@@ -1,11 +1,33 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
class LLMProvider(Protocol):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
async def list_models(self) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
...
|
| 7 |
-
|
| 8 |
async def chat(self, payload: Dict[str, Any], stream: bool = False) -> Any:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
...
|
| 10 |
|
| 11 |
|
|
|
|
| 1 |
+
"""Base protocol for LLM providers."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Protocol
|
| 4 |
|
| 5 |
|
| 6 |
class LLMProvider(Protocol):
|
| 7 |
+
"""Protocol defining the interface for LLM providers.
|
| 8 |
+
|
| 9 |
+
Any class implementing this protocol must provide async methods
|
| 10 |
+
for listing models and generating chat completions.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
async def list_models(self) -> Dict[str, Any]:
|
| 14 |
+
"""List available models.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Dictionary containing model information.
|
| 18 |
+
"""
|
| 19 |
...
|
| 20 |
+
|
| 21 |
async def chat(self, payload: Dict[str, Any], stream: bool = False) -> Any:
|
| 22 |
+
"""Generate chat completion.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
payload: Request payload containing messages and parameters
|
| 26 |
+
stream: Whether to stream the response
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Chat completion response (varies by implementation)
|
| 30 |
+
"""
|
| 31 |
...
|
| 32 |
|
| 33 |
|
app/services/chat_service.py
CHANGED
|
@@ -1,13 +1,33 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
from app.providers import transformers_provider as provider
|
| 4 |
|
| 5 |
|
| 6 |
async def list_models() -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
return await provider.list_models()
|
| 8 |
|
| 9 |
|
| 10 |
-
async def chat(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
return await provider.chat(payload, stream=stream)
|
| 12 |
|
| 13 |
|
|
|
|
| 1 |
+
"""Chat service layer providing abstraction over the provider."""
|
| 2 |
+
from typing import Any, Dict, Union, AsyncIterator
|
| 3 |
|
| 4 |
from app.providers import transformers_provider as provider
|
| 5 |
|
| 6 |
|
| 7 |
async def list_models() -> Dict[str, Any]:
|
| 8 |
+
"""
|
| 9 |
+
List available models.
|
| 10 |
+
|
| 11 |
+
Returns:
|
| 12 |
+
Dictionary containing model list in OpenAI-compatible format
|
| 13 |
+
"""
|
| 14 |
return await provider.list_models()
|
| 15 |
|
| 16 |
|
| 17 |
+
async def chat(
|
| 18 |
+
payload: Dict[str, Any],
|
| 19 |
+
stream: bool = False
|
| 20 |
+
) -> Union[Dict[str, Any], AsyncIterator[str]]:
|
| 21 |
+
"""
|
| 22 |
+
Process chat completion request.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
payload: Request payload containing messages and generation parameters
|
| 26 |
+
stream: Whether to stream the response
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Response dictionary or async iterator for streaming
|
| 30 |
+
"""
|
| 31 |
return await provider.chat(payload, stream=stream)
|
| 32 |
|
| 33 |
|
app/utils/constants.py
CHANGED
|
@@ -1,18 +1,25 @@
|
|
| 1 |
-
"""Application-wide constants."""
|
| 2 |
|
| 3 |
import os
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# Model configuration
|
| 6 |
-
MODEL_NAME = "DragonLLM/qwen3-8b-fin-v1.0"
|
| 7 |
|
| 8 |
# Cache directory - respect HF_HOME if set, otherwise use default
|
| 9 |
-
CACHE_DIR = os.getenv("HF_HOME", "/tmp/huggingface")
|
| 10 |
|
| 11 |
# Hugging Face token environment variable priority order
|
| 12 |
-
HF_TOKEN_VARS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# French language detection patterns
|
| 15 |
-
FRENCH_PHRASES = [
|
| 16 |
"en français",
|
| 17 |
"répondez en français",
|
| 18 |
"réponse française",
|
|
@@ -20,9 +27,11 @@ FRENCH_PHRASES = [
|
|
| 20 |
"expliquez en français",
|
| 21 |
]
|
| 22 |
|
| 23 |
-
FRENCH_CHARS = [
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
FRENCH_PATTERNS = [
|
| 26 |
"qu'est-ce",
|
| 27 |
"qu'est",
|
| 28 |
"expliquez",
|
|
@@ -38,7 +47,7 @@ FRENCH_PATTERNS = [
|
|
| 38 |
"définissez",
|
| 39 |
]
|
| 40 |
|
| 41 |
-
FRENCH_SYSTEM_PROMPT = (
|
| 42 |
"Vous êtes un assistant financier expert. "
|
| 43 |
"Répondez TOUJOURS en français. "
|
| 44 |
"Soyez concis et précis dans vos explications. "
|
|
@@ -46,15 +55,15 @@ FRENCH_SYSTEM_PROMPT = (
|
|
| 46 |
)
|
| 47 |
|
| 48 |
# Qwen3 EOS tokens
|
| 49 |
-
EOS_TOKENS = [151645, 151643] # [<|im_end|>, <|endoftext|>]
|
| 50 |
-
PAD_TOKEN_ID = 151643 # <|endoftext|>
|
| 51 |
|
| 52 |
# Generation defaults
|
| 53 |
-
DEFAULT_MAX_TOKENS = 1000 # Increased for complete answers with concise reasoning
|
| 54 |
-
DEFAULT_TEMPERATURE = 0.7
|
| 55 |
-
DEFAULT_TOP_P = 1.0
|
| 56 |
-
DEFAULT_TOP_K = 20
|
| 57 |
-
REPETITION_PENALTY = 1.05
|
| 58 |
|
| 59 |
# Model initialization constants
|
| 60 |
MODEL_INIT_TIMEOUT_SECONDS = 300 # 5 minutes timeout for model initialization
|
|
|
|
| 1 |
+
"""Application-wide constants and configuration."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
from typing import Final, List
|
| 5 |
+
|
| 6 |
|
| 7 |
# Model configuration
|
| 8 |
+
MODEL_NAME: Final[str] = "DragonLLM/qwen3-8b-fin-v1.0"
|
| 9 |
|
| 10 |
# Cache directory - respect HF_HOME if set, otherwise use default
|
| 11 |
+
CACHE_DIR: Final[str] = os.getenv("HF_HOME", "/tmp/huggingface")
|
| 12 |
|
| 13 |
# Hugging Face token environment variable priority order
|
| 14 |
+
HF_TOKEN_VARS: Final[List[str]] = [
|
| 15 |
+
"HF_TOKEN_LC2",
|
| 16 |
+
"HF_TOKEN_LC",
|
| 17 |
+
"HF_TOKEN",
|
| 18 |
+
"HUGGING_FACE_HUB_TOKEN"
|
| 19 |
+
]
|
| 20 |
|
| 21 |
# French language detection patterns
|
| 22 |
+
FRENCH_PHRASES: Final[List[str]] = [
|
| 23 |
"en français",
|
| 24 |
"répondez en français",
|
| 25 |
"réponse française",
|
|
|
|
| 27 |
"expliquez en français",
|
| 28 |
]
|
| 29 |
|
| 30 |
+
FRENCH_CHARS: Final[List[str]] = [
|
| 31 |
+
"é", "è", "ê", "à", "ç", "ù", "ô", "î", "â", "û", "ë", "ï"
|
| 32 |
+
]
|
| 33 |
|
| 34 |
+
FRENCH_PATTERNS: Final[List[str]] = [
|
| 35 |
"qu'est-ce",
|
| 36 |
"qu'est",
|
| 37 |
"expliquez",
|
|
|
|
| 47 |
"définissez",
|
| 48 |
]
|
| 49 |
|
| 50 |
+
FRENCH_SYSTEM_PROMPT: Final[str] = (
|
| 51 |
"Vous êtes un assistant financier expert. "
|
| 52 |
"Répondez TOUJOURS en français. "
|
| 53 |
"Soyez concis et précis dans vos explications. "
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
# Qwen3 EOS tokens
|
| 58 |
+
EOS_TOKENS: Final[List[int]] = [151645, 151643] # [<|im_end|>, <|endoftext|>]
|
| 59 |
+
PAD_TOKEN_ID: Final[int] = 151643 # <|endoftext|>
|
| 60 |
|
| 61 |
# Generation defaults
|
| 62 |
+
DEFAULT_MAX_TOKENS: Final[int] = 1000 # Increased for complete answers with concise reasoning
|
| 63 |
+
DEFAULT_TEMPERATURE: Final[float] = 0.7
|
| 64 |
+
DEFAULT_TOP_P: Final[float] = 1.0
|
| 65 |
+
DEFAULT_TOP_K: Final[int] = 20
|
| 66 |
+
REPETITION_PENALTY: Final[float] = 1.05
|
| 67 |
|
| 68 |
# Model initialization constants
|
| 69 |
MODEL_INIT_TIMEOUT_SECONDS = 300 # 5 minutes timeout for model initialization
|
app/utils/helpers.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
-
from typing import Optional, Tuple
|
| 6 |
|
| 7 |
from app.utils.constants import HF_TOKEN_VARS, FRENCH_PHRASES, FRENCH_CHARS, FRENCH_PATTERNS
|
| 8 |
|
|
@@ -24,7 +24,7 @@ def get_hf_token() -> Tuple[Optional[str], str]:
|
|
| 24 |
return None, "none"
|
| 25 |
|
| 26 |
|
| 27 |
-
def is_french_request(messages:
|
| 28 |
"""
|
| 29 |
Detect if the request is in French based on user messages.
|
| 30 |
|
|
@@ -55,7 +55,7 @@ def is_french_request(messages: list) -> bool:
|
|
| 55 |
return False
|
| 56 |
|
| 57 |
|
| 58 |
-
def has_french_system_prompt(messages:
|
| 59 |
"""Check if messages already contain a French system prompt."""
|
| 60 |
return any(
|
| 61 |
"français" in msg.get("content", "").lower()
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 6 |
|
| 7 |
from app.utils.constants import HF_TOKEN_VARS, FRENCH_PHRASES, FRENCH_CHARS, FRENCH_PATTERNS
|
| 8 |
|
|
|
|
| 24 |
return None, "none"
|
| 25 |
|
| 26 |
|
| 27 |
+
def is_french_request(messages: List[Dict[str, Any]]) -> bool:
|
| 28 |
"""
|
| 29 |
Detect if the request is in French based on user messages.
|
| 30 |
|
|
|
|
| 55 |
return False
|
| 56 |
|
| 57 |
|
| 58 |
+
def has_french_system_prompt(messages: List[Dict[str, Any]]) -> bool:
|
| 59 |
"""Check if messages already contain a French system prompt."""
|
| 60 |
return any(
|
| 61 |
"français" in msg.get("content", "").lower()
|
app/utils/memory.py
CHANGED
|
@@ -1,12 +1,23 @@
|
|
| 1 |
"""GPU memory management utilities."""
|
| 2 |
|
| 3 |
import gc
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
-
from typing import Optional
|
| 6 |
|
| 7 |
|
| 8 |
-
def clear_gpu_memory(model=None, tokenizer=None):
|
| 9 |
-
"""Clear GPU memory completely.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
if not torch.cuda.is_available():
|
| 11 |
return
|
| 12 |
|
|
|
|
| 1 |
"""GPU memory management utilities."""
|
| 2 |
|
| 3 |
import gc
|
| 4 |
+
from typing import Optional, Any
|
| 5 |
+
|
| 6 |
import torch
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
+
def clear_gpu_memory(model: Optional[Any] = None, tokenizer: Optional[Any] = None) -> None:
|
| 10 |
+
"""Clear GPU memory completely.
|
| 11 |
+
|
| 12 |
+
This function performs aggressive GPU memory cleanup by:
|
| 13 |
+
1. Deleting model and tokenizer objects if provided
|
| 14 |
+
2. Clearing CUDA cache
|
| 15 |
+
3. Running multiple garbage collection passes
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
model: Optional model object to delete
|
| 19 |
+
tokenizer: Optional tokenizer object to delete
|
| 20 |
+
"""
|
| 21 |
if not torch.cuda.is_available():
|
| 22 |
return
|
| 23 |
|