"""Tests for img2img input handling in ``server``.

Covers ``server._save_img2img_image_to_file`` (data-URI input, bare base64
input, oversize rejection) and the ``/api/generate`` endpoint, which must
turn an inline ``img2img_image`` into a file path before enqueueing.
"""

import base64
import io
import os
import tempfile

import pytest
from fastapi import HTTPException
from PIL import Image

import server

# Fixed 8-byte signature every PNG file starts with.
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# Size limit passed to the saver in these tests (10 MiB).
_MAX_SIZE_BYTES = 10 * 1024 * 1024


def _make_png_data():
    """Build a small solid-red 32x32 PNG.

    Returns:
        ``(data_uri, raw)`` where ``data_uri`` is
        ``data:image/png;base64,<payload>`` and ``raw`` is the encoded PNG bytes.
    """
    buf = io.BytesIO()
    Image.new('RGB', (32, 32), color='red').save(buf, format='PNG')
    raw = buf.getvalue()
    b64 = base64.b64encode(raw).decode('ascii')
    return f"data:image/png;base64,{b64}", raw


def _assert_png_file_and_remove(path):
    """Assert that *path* exists and holds PNG data, then delete it.

    The file is removed in a ``finally`` so a failing assertion does not
    leak the temp file produced by the code under test.
    """
    try:
        assert path and os.path.exists(path)
        with open(path, 'rb') as f:
            assert f.read().startswith(_PNG_SIGNATURE)
    finally:
        if path and os.path.exists(path):
            os.remove(path)


def test_save_data_uri():
    data_uri, _ = _make_png_data()
    path = server._save_img2img_image_to_file(data_uri, max_size_bytes=_MAX_SIZE_BYTES)
    _assert_png_file_and_remove(path)


def test_save_bare_base64():
    data_uri, _ = _make_png_data()
    # Strip the "data:image/png;base64," prefix to get the bare payload.
    b64 = data_uri.split(',', 1)[1]
    path = server._save_img2img_image_to_file(b64, max_size_bytes=_MAX_SIZE_BYTES)
    _assert_png_file_and_remove(path)


@pytest.mark.asyncio
async def test_generate_endpoint_converts(monkeypatch, async_server_client):
    data_uri, raw = _make_png_data()
    seen_paths = []

    async def fake_enqueue(pending):
        # The pending request should have had its img2img_image converted to a real file path
        val = pending.req.img2img_image
        seen_paths.append(val)
        assert val and os.path.exists(val)
        with open(val, 'rb') as f:
            assert f.read().startswith(_PNG_SIGNATURE)
        return {'image': 'data:image/png;base64,' + base64.b64encode(raw).decode('ascii')}

    monkeypatch.setattr(server._generation_buffer, 'enqueue', fake_enqueue)
    payload = {
        'prompt': 'test',
        'width': 512,
        'height': 512,
        'num_images': 1,
        'img2img_mode': True,
        'img2img_image': data_uri,
    }
    res = await async_server_client.post('/api/generate', json=payload)
    assert res.status_code == 200
    assert 'image' in res.json()
    # Make sure the conversion checks inside fake_enqueue actually ran.
    assert seen_paths, "enqueue was never called"


def test_save_too_large():
    # Create a base64 string that decodes to more than 10MB
    big_bytes = b'A' * (11 * 1024 * 1024)
    big_b64 = base64.b64encode(big_bytes).decode('ascii')
    with pytest.raises(HTTPException) as excinfo:
        server._save_img2img_image_to_file(big_b64, max_size_bytes=_MAX_SIZE_BYTES)
    assert excinfo.value.status_code == 413