Spaces:
Paused
Paused
| import time | |
| from pathlib import Path | |
| from backend.app.api.schemas import GenerateRequest, JobStatus | |
| from backend.app.jobs.manager import JobManager | |
| from backend.app.providers.factory import ProviderRegistry | |
| from backend.app.storage.history import PromptHistoryStore | |
def test_registry_contains_localai() -> None:
    """The default ``ProviderRegistry`` must resolve the ``"localai"`` provider id."""
    localai = ProviderRegistry().get("localai")
    assert localai.id == "localai"
def test_job_lifecycle(tmp_path: Path, monkeypatch) -> None:
    """Submit a job for the ``dummy`` model and check it completes successfully.

    The job manager module's ``OUTPUT_DIR`` is redirected to ``tmp_path``, and
    the job state file is also placed under ``tmp_path``. The test polls for up
    to 10 seconds for a terminal status. It then requires ``DONE`` with 100%
    progress and at least one image path.
    """
    monkeypatch.setattr("backend.app.jobs.manager.OUTPUT_DIR", tmp_path)
    manager = JobManager(
        ProviderRegistry(),
        PromptHistoryStore(),
        state_file=tmp_path / "jobs_state.json",
    )
    req = GenerateRequest(
        prompt="a mountain",
        negative_prompt="",
        model="dummy",
        size="1024x1024",
        count=1,
        random_seed=True,
        steps=10,
        guidance=7.5,
    )
    job_id = manager.submit(req)

    # The job is not finished when submit() returns, so poll for a terminal
    # status. monotonic() is used because, unlike time(), it cannot jump when
    # the system clock is adjusted.
    deadline = time.monotonic() + 10
    while time.monotonic() < deadline:
        state = manager.get(job_id)
        assert state is not None
        if state.status in {JobStatus.DONE, JobStatus.ERROR}:
            break
        time.sleep(0.1)

    state = manager.get(job_id)
    assert state is not None
    # Show the full state so an ERROR result or a timeout is diagnosable.
    assert state.status == JobStatus.DONE, f"job did not complete: {state!r}"
    assert state.progress == 100
    assert state.image_paths
def test_job_cancel_from_queue(tmp_path: Path, monkeypatch) -> None:
    """Cancelling a freshly submitted job is accepted and recorded on its state.

    ``cancel()`` must return a truthy value. Afterwards the job's state must
    have ``cancel_requested`` set.
    """
    # Match the other job tests: if the job gets started before cancel() takes
    # effect, any output it produces stays under tmp_path rather than the real
    # output directory.
    monkeypatch.setattr("backend.app.jobs.manager.OUTPUT_DIR", tmp_path)
    manager = JobManager(
        ProviderRegistry(),
        PromptHistoryStore(),
        state_file=tmp_path / "jobs_state.json",
    )
    req = GenerateRequest(prompt="cancel me", model="dummy")
    job_id = manager.submit(req)

    assert manager.cancel(job_id)
    state = manager.get(job_id)
    assert state is not None
    assert state.cancel_requested
def test_job_stats_shape(tmp_path: Path) -> None:
    """``JobManager.stats()`` must return exactly the expected counter keys."""
    expected_keys = {
        "cancelled",
        "done",
        "error",
        "last_24h",
        "queued",
        "running",
        "total",
    }
    state_file = tmp_path / "jobs_state.json"
    manager = JobManager(ProviderRegistry(), PromptHistoryStore(), state_file=state_file)
    actual_keys = set(manager.stats().keys())
    assert actual_keys == expected_keys
def test_job_falls_back_when_primary_provider_fails(tmp_path: Path, monkeypatch) -> None:
    """A failing primary provider triggers fallback to the next configured model.

    Auto-fallback is enabled through environment variables. The fallback chain
    is ``localai,dummy`` and the fallback step cap is 12. The job requests the
    ``a1111`` model with 30 steps. The test checks that:

    * the failing primary provider was actually invoked;
    * the job still ends ``DONE`` with images (produced by the fallback);
    * the fallback provider received ``steps`` clamped to 12.
    """
    monkeypatch.setattr("backend.app.jobs.manager.OUTPUT_DIR", tmp_path)
    monkeypatch.setenv("IMAGEFORGE_ENABLE_AUTO_FALLBACK", "1")
    monkeypatch.setenv("IMAGEFORGE_FALLBACK_MODELS", "localai,dummy")
    monkeypatch.setenv("IMAGEFORGE_FALLBACK_MAX_STEPS", "12")

    class FailingProvider:
        """Primary provider stand-in whose ``generate`` always raises."""

        id = "a1111"
        name = "Failing A1111"
        description = ""

        def __init__(self) -> None:
            # Number of generate() calls, so the test can prove the primary
            # was really attempted before the fallback kicked in.
            self.calls = 0

        def is_available(self) -> bool:
            return True

        def generate(self, request, output_dir, progress, is_cancelled):  # noqa: ANN001
            self.calls += 1
            raise RuntimeError("primary failure")

    class WorkingProvider:
        """Fallback stand-in that writes one image and records the step count."""

        id = "localai"
        name = "Working LocalAI"
        description = ""

        def __init__(self) -> None:
            # Steps value of the last request received; checked for the cap.
            self.last_steps = None

        def is_available(self) -> bool:
            return True

        def generate(self, request, output_dir, progress, is_cancelled):  # noqa: ANN001
            self.last_steps = request.steps
            output_dir.mkdir(parents=True, exist_ok=True)
            path = output_dir / "image_01.png"
            path.write_bytes(b"fallback")
            progress(100, "ok")
            from backend.app.providers.interface import ProviderResult

            return ProviderResult(image_paths=[path])

    class FakeRegistry:
        """Minimal registry exposing ``get`` and ``list`` over fixed providers."""

        def __init__(self) -> None:
            self.failing = FailingProvider()
            self.working = WorkingProvider()
            # Both fallback ids resolve to the same working instance.
            self.providers = {
                "a1111": self.failing,
                "localai": self.working,
                "dummy": self.working,
            }

        def get(self, provider_id: str):
            return self.providers[provider_id]

        def list(self):
            return list(self.providers.values())

    registry = FakeRegistry()
    manager = JobManager(registry, PromptHistoryStore(), state_file=tmp_path / "jobs_state.json")
    req = GenerateRequest(prompt="fallback", model="a1111", steps=30, guidance=9.0)
    job_id = manager.submit(req)

    # Poll for a terminal status; monotonic() is used because, unlike time(),
    # it cannot jump when the system clock is adjusted.
    deadline = time.monotonic() + 10
    while time.monotonic() < deadline:
        state = manager.get(job_id)
        assert state is not None
        if state.status in {JobStatus.DONE, JobStatus.ERROR}:
            break
        time.sleep(0.1)

    state = manager.get(job_id)
    assert state is not None
    assert state.status == JobStatus.DONE, f"job did not complete: {state!r}"
    assert state.image_paths
    assert registry.failing.calls >= 1, "primary provider was never attempted"
    assert registry.working.last_steps == 12