File size: 4,521 Bytes
0231daa
 
 
 
 
 
 
 
 
 
 
 
2e5859f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0231daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e5859f
0231daa
 
 
 
 
 
 
 
 
 
 
 
 
 
2e5859f
155ad69
 
 
 
 
2e5859f
 
 
 
 
0231daa
 
 
 
 
 
 
155ad69
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Input validation utilities.

This module provides validation functions for request inputs,
ensuring data quality and preventing abuse.
"""

from typing import List, Dict, Any
from pydantic import BaseModel
from src.core.exceptions import TextTooLongError, BatchTooLargeError, ValidationError


def validate_text(text: str, max_length: int = 8192, allow_empty: bool = False) -> None:
    """
    Validate a single text input.

    Args:
        text: Input text to validate
        max_length: Maximum allowed text length
        allow_empty: Whether to allow empty strings

    Raises:
        ValidationError: If text is not a string, or is empty and not allowed
        TextTooLongError: If text exceeds max_length
    """
    # Mirror the per-item type check in validate_texts so a non-string input
    # raises a clear ValidationError rather than AttributeError on .strip().
    if not isinstance(text, str):
        raise ValidationError("text", f"Expected string, got {type(text).__name__}")

    if not allow_empty and not text.strip():
        raise ValidationError("text", "Text cannot be empty")

    if len(text) > max_length:
        raise TextTooLongError(len(text), max_length)


def validate_texts(
    texts: List[str],
    max_length: int = 8192,
    max_batch_size: int = 100,
    allow_empty: bool = False,
) -> None:
    """
    Validate a list of text inputs.

    Checks the batch as a whole (non-empty, within size limit), then each
    entry individually (type, emptiness, length).

    Args:
        texts: List of texts to validate
        max_length: Maximum allowed length per text
        max_batch_size: Maximum number of texts in batch
        allow_empty: Whether to allow empty strings

    Raises:
        ValidationError: If texts list is empty or contains invalid items
        BatchTooLargeError: If batch size exceeds max_batch_size
        TextTooLongError: If any text exceeds max_length
    """
    # Batch-level checks come first so per-item errors are never reported
    # for an invalid batch.
    if not texts:
        raise ValidationError("texts", "Texts list cannot be empty")

    if len(texts) > max_batch_size:
        raise BatchTooLargeError(len(texts), max_batch_size)

    # Per-item checks, reported with the item's index in the field name.
    for position, item in enumerate(texts):
        field_name = f"texts[{position}]"

        if not isinstance(item, str):
            raise ValidationError(
                field_name, f"Expected string, got {type(item).__name__}"
            )

        if not allow_empty and not item.strip():
            raise ValidationError(field_name, "Text cannot be empty")

        if len(item) > max_length:
            raise TextTooLongError(len(item), max_length)


def validate_model_id(model_id: str, available_models: List[str]) -> None:
    """
    Validate that a model_id exists in available models.

    Args:
        model_id: Model identifier to validate
        available_models: List of available model IDs

    Raises:
        ValidationError: If model_id is invalid
    """
    if not model_id:
        raise ValidationError("model_id", "Model ID cannot be empty")

    # Known model: nothing further to check.
    if model_id in available_models:
        return

    raise ValidationError(
        "model_id",
        f"Model '{model_id}' not found. Available: {', '.join(available_models)}",
    )


def extract_embedding_kwargs(request: BaseModel) -> Dict[str, Any]:
    """
    Extract embedding kwargs from a request object.

    This function extracts both the 'options' field and any extra fields
    passed in the request, combining them into a single kwargs dict.

    Args:
        request: Pydantic request model (EmbedRequest or BatchEmbedRequest)

    Returns:
        Dictionary of kwargs to pass to embedding model

    Example:
        >>> request = EmbedRequest(
        ...     texts=["hello"],
        ...     model_id="qwen3-0.6b",
        ...     options=EmbeddingOptions(normalize_embeddings=True),
        ...     batch_size=32  # Extra field
        ... )
        >>> extract_embedding_kwargs(request)
        {'normalize_embeddings': True, 'batch_size': 32}
    """
    kwargs: Dict[str, Any] = {}

    # Explicit 'options' provide the base set of kwargs; extra request
    # fields are layered on top below.
    if hasattr(request, "options") and request.options is not None:
        kwargs.update(request.options.to_kwargs())

    # Known request fields that must never be forwarded to the model.
    # 'texts' and 'model_id' are included alongside the OpenAI-style
    # aliases ('input', 'model'); previously only the aliases were
    # excluded, so EmbedRequest's own fields leaked into the kwargs
    # (contradicting the example above).
    standard_fields = {
        "input",
        "model",
        "texts",
        "model_id",
        "encoding_format",
        "dimensions",
        "user",
        "options",
        "query",
        "documents",
        "top_k",
    }
    request_dict = request.model_dump()

    # Forward any remaining non-None extras verbatim.
    for key, value in request_dict.items():
        if key not in standard_fields and value is not None:
            kwargs[key] = value

    return kwargs


def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text* (rough heuristic).

    Uses the common ~4-characters-per-token approximation; the result is
    floored at 1 so even an empty string counts as one token.
    """
    approx = len(text) // 4
    return approx if approx > 0 else 1


def count_tokens_batch(texts: List[str]) -> int:
    """Return the estimated total token count across *texts*."""
    total = 0
    for entry in texts:
        total += estimate_tokens(entry)
    return total