File size: 14,942 Bytes
b7d2408 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
"""CVAT API client base class.
This module provides the main CvatApiClient class that handles authentication
and delegates to specialized method groups.
"""
import json
import logging
import time
from typing import Any
import requests
from .annotations import AnnotationsMethods
from .auth import AuthMethods
from .downloads import DownloadsMethods
from .jobs import JobsMethods
from .labels import LabelsMethods
from .projects import ProjectsMethods
from .tasks import TasksMethods
class CvatApiClient:
"""CVAT API client for interacting with CVAT endpoints.
This client provides organized methods for working with CVAT API endpoints,
including authentication, projects, tasks, jobs, annotations, and downloads.
The client is organized into method groups:
**Authentication Methods (client.auth.xxx)**
- get_auth_token() -> str
Authenticate and get authentication token
**Project Methods (client.projects.xxx)**
- get_project_details(project_id, token=None) -> CvatApiProjectDetails
Get project information and metadata
- get_project_labels(project_id, token=None) -> list[CvatApiLabelDefinition]
Get all label definitions for a project
- get_project_job_ids(project_id, token=None) -> list[int]
Get all job IDs associated with a project
- get_project_tasks(project_id, token=None) -> list[int]
Get all task IDs for a project
**Task Methods (client.tasks.xxx)**
- get_task_details(task_id, token=None) -> CvatApiTaskDetails
Get task information and metadata
- get_task_media_metainformation(task_id, token=None) -> CvatApiTaskMediasMetainformation
Get media metadata (frame info, dimensions) for a task
- get_task_job_ids(task_id, token=None) -> list[int]
Get all job IDs for a task
- update_task(task_id, task_data, token=None) -> CvatApiTaskDetails
Update task details (name, assignee, labels, etc.)
**Job Methods (client.jobs.xxx)**
- list_jobs(request, token=None) -> CvatApiJobsListResponse
List jobs with filtering (by project, task, stage, state, assignee, etc.)
- get_job_details(job_id, token=None) -> CvatApiJobDetails
Get job information and metadata
- get_job_media_metainformation(job_id, token=None) -> CvatApiJobMediasMetainformation
Get media metadata (frame info, dimensions) for a job
- update_job(job_id, job_data, token=None) -> CvatApiJobDetails
Update job details (status, assignee, stage, etc.)
**Annotation Methods (client.annotations.xxx)**
- get_job_annotations(job_id, token=None) -> CvatApiJobList
Get all annotations (shapes, tracks, tags) for a job
- get_task_annotations(task_id, token=None) -> CvatApiTaskList
Get all annotations (shapes, tracks, tags) for a task
- put_job_annotations(job_id, annotations, token=None) -> CvatApiJobList
Update annotations for a job
- put_task_annotations(task_id, annotations, token=None) -> CvatApiTaskList
Update annotations for a task
**Download Methods (client.downloads.xxx)**
- download_job_image(job_id, frame_number, output_path, token=None) -> None
Download a single frame image from a job
- download_task_image(task_id, frame_number, output_path, token=None) -> None
Download a single frame image from a task
- download_job_images(job_id, output_dir, token=None) -> list[Path]
Download all images from a job
- download_task_images(task_id, output_dir, token=None) -> list[Path]
Download all images from a task
- download_job_data_chunk(job_id, chunk_number, output_path, token=None) -> None
Download a specific data chunk from a job
- download_task_data_chunk(task_id, chunk_number, output_path, token=None) -> None
Download a specific data chunk from a task
- download_all_job_chunks(job_id, output_dir, token=None) -> list[Path]
Download all data chunks from a job
- download_all_task_chunks(task_id, output_dir, token=None) -> list[Path]
Download all data chunks from a task
**Label Methods (client.labels.xxx)**
- get_label_details(label_id, token=None) -> CvatApiLabelDefinition
Get label definition details
Attributes:
cvat_host: Base URL of the CVAT server
cvat_username: Username for authentication
cvat_password: Password for authentication
cvat_organization: Organization name
cvat_auth_timeout: Timeout for authentication requests in seconds
cvat_api_timeout: Timeout for API requests in seconds
cvat_token: Authentication token (set after initialization)
Example:
>>> client = CvatApiClient(
... cvat_host="https://cvat.example.com",
... cvat_username="user",
... cvat_password="pass",
... cvat_organization="my-org"
... )
>>> # Get project information
>>> project = client.projects.get_project_details(12345)
>>> # Get all jobs in a project
>>> job_ids = client.projects.get_project_job_ids(12345)
>>> # Download task annotations
>>> annotations = client.annotations.get_task_annotations(67890)
>>> # Download all images from a job
>>> paths = client.downloads.download_job_images(111, Path("./output"))
"""
def __init__(
self,
cvat_host: str,
cvat_username: str,
cvat_password: str,
cvat_organization: str,
cvat_auth_timeout: float = 30.0,
cvat_api_timeout: float = 60.0,
max_retries: int = 5,
initial_retry_delay: float = 1.0,
max_retry_delay: float = 60.0,
):
"""Initialize the CVAT API client.
Args:
cvat_host: Base URL of the CVAT server
cvat_username: Username for authentication
cvat_password: Password for authentication
cvat_organization: Organization name
cvat_auth_timeout: Timeout for authentication requests in seconds
cvat_api_timeout: Timeout for API requests in seconds
max_retries: Maximum number of retry attempts for transient errors
initial_retry_delay: Initial delay in seconds before first retry
max_retry_delay: Maximum delay in seconds between retries
"""
self.cvat_host = cvat_host
self.cvat_username = cvat_username
self.cvat_password = cvat_password
self.cvat_organization = cvat_organization
self.cvat_auth_timeout = cvat_auth_timeout
self.cvat_api_timeout = cvat_api_timeout
self.max_retries = max_retries
self.initial_retry_delay = initial_retry_delay
self.max_retry_delay = max_retry_delay
self.logger = logging.getLogger(__name__)
# Initialize method groups
self.auth = AuthMethods(self)
self.projects = ProjectsMethods(self)
self.tasks = TasksMethods(self)
self.jobs = JobsMethods(self)
self.annotations = AnnotationsMethods(self)
self.downloads = DownloadsMethods(self)
self.labels = LabelsMethods(self)
# Authenticate
self.cvat_token = self.auth.get_auth_token()
self.logger.info("🔑 CVAT API initialized.")
# 🔹 Helper methods for common patterns
def _get_headers(
self, token: str | None = None, with_organization: bool = True
) -> dict[str, str]:
"""Create standard headers for API requests.
Args:
token: Authentication token (uses self.cvat_token if None)
with_organization: Whether to include X-Organization header
Returns:
Dictionary of HTTP headers
"""
headers = {"Authorization": f"Token {token or self.cvat_token}"}
if with_organization:
headers["X-Organization"] = self.cvat_organization
return headers
def _handle_response_errors(
self, response: requests.Response | None, error_prefix: str
) -> None:
"""Handle common response errors with consistent logging.
Args:
response: HTTP response object (optional)
error_prefix: Error message prefix for logging
"""
if isinstance(response, requests.Response):
self.logger.error("%s: %d", error_prefix, response.status_code)
response_text = json.dumps(response.text, indent=2) if response.text else 'No response text'
self.logger.error("❌ Response text:\n%s", response_text)
def _make_request(
self,
method: str,
url: str,
headers: dict[str, str],
resource_name: str,
resource_id: int | None = None,
params: dict[str, Any] | None = None,
json_data: dict[str, Any] | None = None,
timeout: float | None = None,
response_model: type | None = None,
) -> Any:
"""Make HTTP request with standard error handling and exponential backoff retry.
Automatically retries on transient errors with exponential backoff:
- Network timeouts
- Connection errors
- HTTP 502 (Bad Gateway), 503 (Service Unavailable), 504 (Gateway Timeout)
Args:
method: HTTP method (GET, POST, PUT, PATCH, DELETE)
url: Request URL
headers: HTTP headers
resource_name: Resource name for logging
resource_id: Resource ID for logging (optional)
params: Query parameters (optional)
json_data: JSON body data (optional)
timeout: Request timeout in seconds (optional)
response_model: Pydantic model for response validation (optional)
Returns:
Response object (Pydantic model if response_model provided, else raw response)
Raises:
requests.RequestException: If request fails after all retries
"""
timeout = timeout or self.cvat_api_timeout
resource_desc = (
f"{resource_name} {resource_id}" if resource_id else resource_name
)
# Transient error status codes that should trigger retry
TRANSIENT_STATUS_CODES = {502, 503, 504}
last_exception = None
for attempt in range(self.max_retries + 1):
if attempt > 0:
# Calculate exponential backoff delay
delay = min(
self.initial_retry_delay * (2 ** (attempt - 1)),
self.max_retry_delay
)
self.logger.warning(
"⏳ Retry %d/%d for %s after %.1fs delay...",
attempt,
self.max_retries,
resource_desc,
delay
)
time.sleep(delay)
# No emoji for routine API requests (only for milestones)
self.logger.debug("%s request for %s (attempt %d)", method.upper(), resource_desc, attempt + 1)
try:
if method.upper() == "GET":
response = requests.get(
url, headers=headers, params=params, timeout=timeout
)
elif method.upper() == "POST":
response = requests.post(
url, headers=headers, json=json_data, timeout=timeout
)
elif method.upper() == "PUT":
response = requests.put(
url, headers=headers, json=json_data, timeout=timeout
)
elif method.upper() == "PATCH":
response = requests.patch(
url, headers=headers, json=json_data, timeout=timeout
)
elif method.upper() == "DELETE":
response = requests.delete(url, headers=headers, timeout=timeout)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
# Parse and validate response if model provided
if response_model and response.content:
response_data = response.json()
return response_model.model_validate(response_data)
return response
except requests.Timeout as e:
last_exception = e
self.logger.warning(
"⚠️ Request timeout for %s (attempt %d/%d)",
resource_desc,
attempt + 1,
self.max_retries + 1
)
# Continue to retry
except requests.ConnectionError as e:
last_exception = e
self.logger.warning(
"⚠️ Connection error for %s (attempt %d/%d)",
resource_desc,
attempt + 1,
self.max_retries + 1
)
# Continue to retry
except requests.HTTPError as e:
last_exception = e
# Only retry on transient HTTP errors
if hasattr(e, "response") and e.response.status_code in TRANSIENT_STATUS_CODES:
self.logger.warning(
"⚠️ Transient HTTP %d error for %s (attempt %d/%d)",
e.response.status_code,
resource_desc,
attempt + 1,
self.max_retries + 1
)
# Continue to retry
else:
# Non-transient error, don't retry
self._handle_response_errors(
e.response if hasattr(e, "response") else None,
f"Failed to {method.upper()} {resource_desc}",
)
raise
except requests.RequestException as e:
# Other request exceptions - don't retry
last_exception = e
self._handle_response_errors(
e.response if hasattr(e, "response") else None,
f"Failed to {method.upper()} {resource_desc}",
)
raise
# All retries exhausted
self.logger.error(
"❌ All %d retry attempts exhausted for %s",
self.max_retries + 1,
resource_desc
)
if last_exception:
self._handle_response_errors(
last_exception.response if hasattr(last_exception, "response") else None,
f"Failed to {method.upper()} {resource_desc} after {self.max_retries + 1} attempts",
)
raise last_exception
# Should never reach here, but just in case
raise RuntimeError(f"Request failed after {self.max_retries + 1} attempts")
|