refactor: Enhance codebase with comprehensive improvements for CodeRabbit review
Browse files- app/config.py: Add Pydantic Field descriptions, type validation, and docstrings
- Use Literal type for log_level validation
- Add comprehensive class and field documentation
- app/main.py: Improve structure and documentation
- Add module docstring
- Enhance FastAPI app configuration with description and version
- Add comprehensive docstrings to all functions
- Use settings.model instead of hardcoded model name
- app/utils/memory.py: Add type hints and comprehensive docstring
- Add Optional[Any] type hints for model/tokenizer parameters
- Add return type annotation (None)
- Document function behavior and cleanup process
- app/models/openai.py: Add Pydantic Field validation and docstrings
- Add comprehensive docstrings to all model classes
- Add Field descriptions and validation (ge, le constraints)
- Improve type safety with proper Field annotations
- app/providers/base.py: Add protocol documentation
- Add module docstring
- Add comprehensive docstrings to LLMProvider protocol
- app/utils/constants.py: Add Final type hints for immutability
- Use typing.Final for all constants to indicate immutability
- Improve code clarity and type safety
These changes expand CodeRabbit review coverage to include config, models,
base providers, and utility modules that were not previously reviewed.
- app/config.py +26 -4
- app/main.py +34 -10
- app/models/openai.py +70 -21
- app/providers/base.py +24 -2
- app/utils/constants.py +24 -15
- app/utils/memory.py +14 -3
|
@@ -1,11 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 2 |
|
| 3 |
|
| 4 |
class Settings(BaseSettings):
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
model_config = SettingsConfigDict(
|
| 11 |
env_file=".env",
|
|
|
|
| 1 |
+
"""Application configuration using Pydantic settings."""
|
| 2 |
+
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from pydantic import Field
|
| 5 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 6 |
|
| 7 |
|
| 8 |
class Settings(BaseSettings):
|
| 9 |
+
"""Application settings loaded from environment variables.
|
| 10 |
+
|
| 11 |
+
Supports loading from .env file with UTF-8 encoding.
|
| 12 |
+
All settings can be overridden via environment variables.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
model: str = Field(
|
| 16 |
+
default="DragonLLM/qwen3-8b-fin-v1.0",
|
| 17 |
+
description="Hugging Face model identifier"
|
| 18 |
+
)
|
| 19 |
+
service_api_key: str | None = Field(
|
| 20 |
+
default=None,
|
| 21 |
+
description="Optional API key for authentication (SERVICE_API_KEY env var)"
|
| 22 |
+
)
|
| 23 |
+
log_level: Literal["debug", "info", "warning", "error"] = Field(
|
| 24 |
+
default="info",
|
| 25 |
+
description="Logging level"
|
| 26 |
+
)
|
| 27 |
+
force_model_reload: bool = Field(
|
| 28 |
+
default=False,
|
| 29 |
+
description="Force model reload from Hugging Face, bypassing cache (FORCE_MODEL_RELOAD env var)"
|
| 30 |
+
)
|
| 31 |
|
| 32 |
model_config = SettingsConfigDict(
|
| 33 |
env_file=".env",
|
|
@@ -1,15 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Dict
|
|
|
|
| 2 |
from fastapi import FastAPI
|
|
|
|
|
|
|
| 3 |
from app.middleware import api_key_guard
|
| 4 |
from app.routers import openai_api
|
| 5 |
-
from app.config import settings
|
| 6 |
-
import logging
|
| 7 |
|
| 8 |
# Configure logging
|
| 9 |
logging.basicConfig(level=logging.INFO)
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Mount routers
|
| 15 |
app.include_router(openai_api.router, prefix="/v1")
|
|
@@ -17,10 +26,14 @@ app.include_router(openai_api.router, prefix="/v1")
|
|
| 17 |
# Optional API key middleware
|
| 18 |
app.middleware("http")(api_key_guard)
|
| 19 |
|
|
|
|
| 20 |
@app.on_event("startup")
|
| 21 |
-
async def startup_event():
|
| 22 |
-
"""Startup event - initialize model in background
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
logger.info("Starting LLM Pro Finance API...")
|
| 25 |
|
| 26 |
force_reload = settings.force_model_reload
|
|
@@ -29,7 +42,8 @@ async def startup_event():
|
|
| 29 |
|
| 30 |
logger.info("Initializing model in background thread...")
|
| 31 |
|
| 32 |
-
def load_model():
|
|
|
|
| 33 |
from app.providers.transformers_provider import initialize_model
|
| 34 |
initialize_model(force_reload=force_reload)
|
| 35 |
|
|
@@ -38,20 +52,30 @@ async def startup_event():
|
|
| 38 |
thread.start()
|
| 39 |
logger.info("Model initialization started in background")
|
| 40 |
|
|
|
|
| 41 |
@app.get("/")
|
| 42 |
async def root() -> Dict[str, str]:
|
| 43 |
-
"""Root endpoint returning API status and information.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return {
|
| 45 |
"status": "ok",
|
| 46 |
"service": "Qwen Open Finance R 8B Inference",
|
| 47 |
"version": "1.0.0",
|
| 48 |
-
"model":
|
| 49 |
"backend": "Transformers"
|
| 50 |
}
|
| 51 |
|
|
|
|
| 52 |
@app.get("/health")
|
| 53 |
async def health() -> Dict[str, str]:
|
| 54 |
-
"""Health check endpoint.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
return {"status": "healthy", "service": "LLM Pro Finance API"}
|
| 56 |
|
| 57 |
|
|
|
|
| 1 |
+
"""Main FastAPI application entry point."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import threading
|
| 5 |
from typing import Dict
|
| 6 |
+
|
| 7 |
from fastapi import FastAPI
|
| 8 |
+
|
| 9 |
+
from app.config import settings
|
| 10 |
from app.middleware import api_key_guard
|
| 11 |
from app.routers import openai_api
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Configure logging
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
+
app = FastAPI(
|
| 18 |
+
title="LLM Pro Finance API (Transformers)",
|
| 19 |
+
description="OpenAI-compatible API for financial LLM inference",
|
| 20 |
+
version="1.0.0"
|
| 21 |
+
)
|
| 22 |
|
| 23 |
# Mount routers
|
| 24 |
app.include_router(openai_api.router, prefix="/v1")
|
|
|
|
| 26 |
# Optional API key middleware
|
| 27 |
app.middleware("http")(api_key_guard)
|
| 28 |
|
| 29 |
+
|
| 30 |
@app.on_event("startup")
|
| 31 |
+
async def startup_event() -> None:
|
| 32 |
+
"""Startup event - initialize model in background thread.
|
| 33 |
+
|
| 34 |
+
Loads the model asynchronously to avoid blocking the API startup.
|
| 35 |
+
Model loading happens in a daemon thread so it doesn't prevent shutdown.
|
| 36 |
+
"""
|
| 37 |
logger.info("Starting LLM Pro Finance API...")
|
| 38 |
|
| 39 |
force_reload = settings.force_model_reload
|
|
|
|
| 42 |
|
| 43 |
logger.info("Initializing model in background thread...")
|
| 44 |
|
| 45 |
+
def load_model() -> None:
|
| 46 |
+
"""Load the model in a background thread."""
|
| 47 |
from app.providers.transformers_provider import initialize_model
|
| 48 |
initialize_model(force_reload=force_reload)
|
| 49 |
|
|
|
|
| 52 |
thread.start()
|
| 53 |
logger.info("Model initialization started in background")
|
| 54 |
|
| 55 |
+
|
| 56 |
@app.get("/")
|
| 57 |
async def root() -> Dict[str, str]:
|
| 58 |
+
"""Root endpoint returning API status and information.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dictionary containing API status, service name, version, model, and backend.
|
| 62 |
+
"""
|
| 63 |
return {
|
| 64 |
"status": "ok",
|
| 65 |
"service": "Qwen Open Finance R 8B Inference",
|
| 66 |
"version": "1.0.0",
|
| 67 |
+
"model": settings.model,
|
| 68 |
"backend": "Transformers"
|
| 69 |
}
|
| 70 |
|
| 71 |
+
|
| 72 |
@app.get("/health")
|
| 73 |
async def health() -> Dict[str, str]:
|
| 74 |
+
"""Health check endpoint for monitoring and load balancers.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Dictionary with service health status.
|
| 78 |
+
"""
|
| 79 |
return {"status": "healthy", "service": "LLM Pro Finance API"}
|
| 80 |
|
| 81 |
|
|
@@ -1,4 +1,7 @@
|
|
|
|
|
|
|
|
| 1 |
from typing import List, Literal, Optional
|
|
|
|
| 2 |
from pydantic import BaseModel, Field
|
| 3 |
|
| 4 |
|
|
@@ -6,42 +9,88 @@ Role = Literal["system", "user", "assistant", "tool"]
|
|
| 6 |
|
| 7 |
|
| 8 |
class Message(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
role: Role
|
| 10 |
-
content: str
|
| 11 |
|
| 12 |
|
| 13 |
class ChatCompletionRequest(BaseModel):
|
| 14 |
-
model
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class ChoiceMessage(BaseModel):
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class Choice(BaseModel):
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
class Usage(BaseModel):
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
class ChatCompletionResponse(BaseModel):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
|
|
|
| 1 |
+
"""OpenAI-compatible API models using Pydantic."""
|
| 2 |
+
|
| 3 |
from typing import List, Literal, Optional
|
| 4 |
+
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
| 7 |
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class Message(BaseModel):
|
| 12 |
+
"""A single message in a conversation.
|
| 13 |
+
|
| 14 |
+
Attributes:
|
| 15 |
+
role: The role of the message sender
|
| 16 |
+
content: The text content of the message
|
| 17 |
+
"""
|
| 18 |
role: Role
|
| 19 |
+
content: str = Field(..., description="Message content")
|
| 20 |
|
| 21 |
|
| 22 |
class ChatCompletionRequest(BaseModel):
|
| 23 |
+
"""Request model for chat completions endpoint.
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
model: Optional model identifier (uses default from config if not provided)
|
| 27 |
+
messages: List of messages in the conversation
|
| 28 |
+
temperature: Sampling temperature (0-2)
|
| 29 |
+
max_tokens: Maximum tokens to generate
|
| 30 |
+
stream: Whether to stream the response
|
| 31 |
+
top_p: Nucleus sampling parameter
|
| 32 |
+
"""
|
| 33 |
+
model: Optional[str] = Field(default=None, description="Model identifier")
|
| 34 |
+
messages: List[Message] = Field(..., description="Conversation messages")
|
| 35 |
+
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
|
| 36 |
+
max_tokens: Optional[int] = Field(default=None, ge=1, description="Maximum tokens to generate")
|
| 37 |
+
stream: Optional[bool] = Field(default=False, description="Stream response")
|
| 38 |
+
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter")
|
| 39 |
|
| 40 |
|
| 41 |
class ChoiceMessage(BaseModel):
|
| 42 |
+
"""Assistant message in a completion choice.
|
| 43 |
+
|
| 44 |
+
Attributes:
|
| 45 |
+
role: Always "assistant" for completion messages
|
| 46 |
+
content: The generated message content
|
| 47 |
+
"""
|
| 48 |
+
role: Literal["assistant"] = "assistant"
|
| 49 |
+
content: Optional[str] = Field(default=None, description="Generated message content")
|
| 50 |
|
| 51 |
|
| 52 |
class Choice(BaseModel):
|
| 53 |
+
"""A single completion choice.
|
| 54 |
+
|
| 55 |
+
Attributes:
|
| 56 |
+
index: Choice index
|
| 57 |
+
message: The generated message
|
| 58 |
+
finish_reason: Reason why generation finished (stop, length, etc.)
|
| 59 |
+
"""
|
| 60 |
+
index: int = Field(..., description="Choice index")
|
| 61 |
+
message: ChoiceMessage = Field(..., description="Generated message")
|
| 62 |
+
finish_reason: Optional[str] = Field(default=None, description="Reason for completion")
|
| 63 |
|
| 64 |
|
| 65 |
class Usage(BaseModel):
|
| 66 |
+
"""Token usage statistics.
|
| 67 |
+
|
| 68 |
+
Attributes:
|
| 69 |
+
prompt_tokens: Number of tokens in the prompt
|
| 70 |
+
completion_tokens: Number of tokens in the completion
|
| 71 |
+
total_tokens: Total tokens used
|
| 72 |
+
"""
|
| 73 |
+
prompt_tokens: int = Field(..., ge=0, description="Tokens in prompt")
|
| 74 |
+
completion_tokens: int = Field(..., ge=0, description="Tokens in completion")
|
| 75 |
+
total_tokens: int = Field(..., ge=0, description="Total tokens used")
|
| 76 |
|
| 77 |
|
| 78 |
class ChatCompletionResponse(BaseModel):
|
| 79 |
+
"""Response model for chat completions endpoint.
|
| 80 |
+
|
| 81 |
+
Attributes:
|
| 82 |
+
id: Unique completion ID
|
| 83 |
+
object: Always "chat.completion"
|
| 84 |
+
created: Unix timestamp of creation
|
| 85 |
+
model: Model identifier used
|
| 86 |
+
choices: List of completion choices
|
| 87 |
+
usage: Optional token usage statistics
|
| 88 |
+
"""
|
| 89 |
+
id: str = Field(..., description="Completion ID")
|
| 90 |
+
object: Literal["chat.completion"] = Field(default="chat.completion", description="Object type")
|
| 91 |
+
created: int = Field(..., description="Unix timestamp")
|
| 92 |
+
model: str = Field(..., description="Model identifier")
|
| 93 |
+
choices: List[Choice] = Field(..., description="Completion choices")
|
| 94 |
+
usage: Optional[Usage] = Field(default=None, description="Token usage statistics")
|
| 95 |
|
| 96 |
|
|
@@ -1,11 +1,33 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
class LLMProvider(Protocol):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
async def list_models(self) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
...
|
| 7 |
-
|
| 8 |
async def chat(self, payload: Dict[str, Any], stream: bool = False) -> Any:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
...
|
| 10 |
|
| 11 |
|
|
|
|
| 1 |
+
"""Base protocol for LLM providers."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Protocol
|
| 4 |
|
| 5 |
|
| 6 |
class LLMProvider(Protocol):
|
| 7 |
+
"""Protocol defining the interface for LLM providers.
|
| 8 |
+
|
| 9 |
+
Any class implementing this protocol must provide async methods
|
| 10 |
+
for listing models and generating chat completions.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
async def list_models(self) -> Dict[str, Any]:
|
| 14 |
+
"""List available models.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Dictionary containing model information.
|
| 18 |
+
"""
|
| 19 |
...
|
| 20 |
+
|
| 21 |
async def chat(self, payload: Dict[str, Any], stream: bool = False) -> Any:
|
| 22 |
+
"""Generate chat completion.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
payload: Request payload containing messages and parameters
|
| 26 |
+
stream: Whether to stream the response
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Chat completion response (varies by implementation)
|
| 30 |
+
"""
|
| 31 |
...
|
| 32 |
|
| 33 |
|
|
@@ -1,18 +1,25 @@
|
|
| 1 |
-
"""Application-wide constants."""
|
| 2 |
|
| 3 |
import os
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# Model configuration
|
| 6 |
-
MODEL_NAME = "DragonLLM/qwen3-8b-fin-v1.0"
|
| 7 |
|
| 8 |
# Cache directory - respect HF_HOME if set, otherwise use default
|
| 9 |
-
CACHE_DIR = os.getenv("HF_HOME", "/tmp/huggingface")
|
| 10 |
|
| 11 |
# Hugging Face token environment variable priority order
|
| 12 |
-
HF_TOKEN_VARS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# French language detection patterns
|
| 15 |
-
FRENCH_PHRASES = [
|
| 16 |
"en français",
|
| 17 |
"répondez en français",
|
| 18 |
"réponse française",
|
|
@@ -20,9 +27,11 @@ FRENCH_PHRASES = [
|
|
| 20 |
"expliquez en français",
|
| 21 |
]
|
| 22 |
|
| 23 |
-
FRENCH_CHARS = [
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
FRENCH_PATTERNS = [
|
| 26 |
"qu'est-ce",
|
| 27 |
"qu'est",
|
| 28 |
"expliquez",
|
|
@@ -38,7 +47,7 @@ FRENCH_PATTERNS = [
|
|
| 38 |
"définissez",
|
| 39 |
]
|
| 40 |
|
| 41 |
-
FRENCH_SYSTEM_PROMPT = (
|
| 42 |
"Vous êtes un assistant financier expert. "
|
| 43 |
"Répondez TOUJOURS en français. "
|
| 44 |
"Soyez concis et précis dans vos explications. "
|
|
@@ -46,13 +55,13 @@ FRENCH_SYSTEM_PROMPT = (
|
|
| 46 |
)
|
| 47 |
|
| 48 |
# Qwen3 EOS tokens
|
| 49 |
-
EOS_TOKENS = [151645, 151643] # [<|im_end|>, <|endoftext|>]
|
| 50 |
-
PAD_TOKEN_ID = 151643 # <|endoftext|>
|
| 51 |
|
| 52 |
# Generation defaults
|
| 53 |
-
DEFAULT_MAX_TOKENS = 1000 # Increased for complete answers with concise reasoning
|
| 54 |
-
DEFAULT_TEMPERATURE = 0.7
|
| 55 |
-
DEFAULT_TOP_P = 1.0
|
| 56 |
-
DEFAULT_TOP_K = 20
|
| 57 |
-
REPETITION_PENALTY = 1.05
|
| 58 |
|
|
|
|
| 1 |
+
"""Application-wide constants and configuration."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
from typing import Final, List
|
| 5 |
+
|
| 6 |
|
| 7 |
# Model configuration
|
| 8 |
+
MODEL_NAME: Final[str] = "DragonLLM/qwen3-8b-fin-v1.0"
|
| 9 |
|
| 10 |
# Cache directory - respect HF_HOME if set, otherwise use default
|
| 11 |
+
CACHE_DIR: Final[str] = os.getenv("HF_HOME", "/tmp/huggingface")
|
| 12 |
|
| 13 |
# Hugging Face token environment variable priority order
|
| 14 |
+
HF_TOKEN_VARS: Final[List[str]] = [
|
| 15 |
+
"HF_TOKEN_LC2",
|
| 16 |
+
"HF_TOKEN_LC",
|
| 17 |
+
"HF_TOKEN",
|
| 18 |
+
"HUGGING_FACE_HUB_TOKEN"
|
| 19 |
+
]
|
| 20 |
|
| 21 |
# French language detection patterns
|
| 22 |
+
FRENCH_PHRASES: Final[List[str]] = [
|
| 23 |
"en français",
|
| 24 |
"répondez en français",
|
| 25 |
"réponse française",
|
|
|
|
| 27 |
"expliquez en français",
|
| 28 |
]
|
| 29 |
|
| 30 |
+
FRENCH_CHARS: Final[List[str]] = [
|
| 31 |
+
"é", "è", "ê", "à", "ç", "ù", "ô", "î", "â", "û", "ë", "ï"
|
| 32 |
+
]
|
| 33 |
|
| 34 |
+
FRENCH_PATTERNS: Final[List[str]] = [
|
| 35 |
"qu'est-ce",
|
| 36 |
"qu'est",
|
| 37 |
"expliquez",
|
|
|
|
| 47 |
"définissez",
|
| 48 |
]
|
| 49 |
|
| 50 |
+
FRENCH_SYSTEM_PROMPT: Final[str] = (
|
| 51 |
"Vous êtes un assistant financier expert. "
|
| 52 |
"Répondez TOUJOURS en français. "
|
| 53 |
"Soyez concis et précis dans vos explications. "
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
# Qwen3 EOS tokens
|
| 58 |
+
EOS_TOKENS: Final[List[int]] = [151645, 151643] # [<|im_end|>, <|endoftext|>]
|
| 59 |
+
PAD_TOKEN_ID: Final[int] = 151643 # <|endoftext|>
|
| 60 |
|
| 61 |
# Generation defaults
|
| 62 |
+
DEFAULT_MAX_TOKENS: Final[int] = 1000 # Increased for complete answers with concise reasoning
|
| 63 |
+
DEFAULT_TEMPERATURE: Final[float] = 0.7
|
| 64 |
+
DEFAULT_TOP_P: Final[float] = 1.0
|
| 65 |
+
DEFAULT_TOP_K: Final[int] = 20
|
| 66 |
+
REPETITION_PENALTY: Final[float] = 1.05
|
| 67 |
|
|
@@ -1,12 +1,23 @@
|
|
| 1 |
"""GPU memory management utilities."""
|
| 2 |
|
| 3 |
import gc
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
-
from typing import Optional
|
| 6 |
|
| 7 |
|
| 8 |
-
def clear_gpu_memory(model=None, tokenizer=None):
|
| 9 |
-
"""Clear GPU memory completely.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
if not torch.cuda.is_available():
|
| 11 |
return
|
| 12 |
|
|
|
|
| 1 |
"""GPU memory management utilities."""
|
| 2 |
|
| 3 |
import gc
|
| 4 |
+
from typing import Optional, Any
|
| 5 |
+
|
| 6 |
import torch
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
+
def clear_gpu_memory(model: Optional[Any] = None, tokenizer: Optional[Any] = None) -> None:
|
| 10 |
+
"""Clear GPU memory completely.
|
| 11 |
+
|
| 12 |
+
This function performs aggressive GPU memory cleanup by:
|
| 13 |
+
1. Deleting model and tokenizer objects if provided
|
| 14 |
+
2. Clearing CUDA cache
|
| 15 |
+
3. Running multiple garbage collection passes
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
model: Optional model object to delete
|
| 19 |
+
tokenizer: Optional tokenizer object to delete
|
| 20 |
+
"""
|
| 21 |
if not torch.cuda.is_available():
|
| 22 |
return
|
| 23 |
|