Spaces:
Paused
Paused
| from __future__ import annotations | |
| import unittest | |
| from unittest import mock | |
| from fastapi import FastAPI | |
| from fastapi.testclient import TestClient | |
| import api.image_tasks as image_tasks_module | |
| AUTH_HEADERS = {"Authorization": "Bearer chatgpt2api"} | |
| class FakeImageTaskService: | |
| def __init__(self): | |
| self.generation_calls = [] | |
| self.edit_calls = [] | |
| def submit_generation(self, identity, **kwargs): | |
| self.generation_calls.append((identity, kwargs)) | |
| return { | |
| "id": kwargs["client_task_id"], | |
| "status": "success", | |
| "mode": "generate", | |
| "created_at": "2026-01-01 00:00:00", | |
| "updated_at": "2026-01-01 00:00:00", | |
| "data": [{"url": f"{kwargs['base_url']}/images/fake.png"}], | |
| } | |
| def submit_edit(self, identity, **kwargs): | |
| self.edit_calls.append((identity, kwargs)) | |
| return { | |
| "id": kwargs["client_task_id"], | |
| "status": "queued", | |
| "mode": "edit", | |
| "created_at": "2026-01-01 00:00:00", | |
| "updated_at": "2026-01-01 00:00:00", | |
| } | |
| def list_tasks(self, _identity, ids): | |
| return { | |
| "items": [ | |
| { | |
| "id": task_id, | |
| "status": "success", | |
| "mode": "generate", | |
| "created_at": "2026-01-01 00:00:00", | |
| "updated_at": "2026-01-01 00:00:00", | |
| "data": [{"url": "http://testserver/images/fake.png"}], | |
| } | |
| for task_id in ids | |
| if task_id != "missing" | |
| ], | |
| "missing_ids": [task_id for task_id in ids if task_id == "missing"], | |
| } | |
| class ImageTasksApiTests(unittest.TestCase): | |
| def setUp(self): | |
| self.fake_service = FakeImageTaskService() | |
| self.service_patcher = mock.patch.object(image_tasks_module, "image_task_service", self.fake_service) | |
| self.service_patcher.start() | |
| self.addCleanup(self.service_patcher.stop) | |
| app = FastAPI() | |
| app.include_router(image_tasks_module.create_router()) | |
| self.client = TestClient(app) | |
| def test_create_generation_task(self): | |
| response = self.client.post( | |
| "/api/image-tasks/generations", | |
| headers=AUTH_HEADERS, | |
| json={"client_task_id": "task-1", "prompt": "cat", "model": "gpt-image-2"}, | |
| ) | |
| self.assertEqual(response.status_code, 200, response.text) | |
| payload = response.json() | |
| self.assertEqual(payload["id"], "task-1") | |
| self.assertEqual(payload["status"], "success") | |
| self.assertEqual(len(self.fake_service.generation_calls), 1) | |
| def test_create_edit_task_accepts_multiple_images(self): | |
| response = self.client.post( | |
| "/api/image-tasks/edits", | |
| headers=AUTH_HEADERS, | |
| data={"client_task_id": "edit-1", "prompt": "edit", "model": "gpt-image-2"}, | |
| files=[ | |
| ("image", ("one.png", b"one", "image/png")), | |
| ("image", ("two.png", b"two", "image/png")), | |
| ], | |
| ) | |
| self.assertEqual(response.status_code, 200, response.text) | |
| self.assertEqual(response.json()["id"], "edit-1") | |
| self.assertEqual(len(self.fake_service.edit_calls), 1) | |
| images = self.fake_service.edit_calls[0][1]["images"] | |
| self.assertEqual(len(images), 2) | |
| def test_list_tasks_reports_missing_ids(self): | |
| response = self.client.get("/api/image-tasks?ids=task-1,missing", headers=AUTH_HEADERS) | |
| self.assertEqual(response.status_code, 200, response.text) | |
| payload = response.json() | |
| self.assertEqual([item["id"] for item in payload["items"]], ["task-1"]) | |
| self.assertEqual(payload["missing_ids"], ["missing"]) | |
| if __name__ == "__main__": | |
| unittest.main() | |