| | import pytest
|
| | from utils import *
|
| | import base64
|
| | import requests
|
| |
|
# Server handle shared by every test in this module. It is only declared here;
# the autouse `create_server` fixture below binds it before each test runs.
server: ServerProcess
|
| |
|
def get_img_url(id: str) -> str:
    """Resolve a symbolic test-image identifier to the value sent to the server.

    Recognised identifiers:

    * ``IMG_URL_0`` / ``IMG_URL_1`` -- the remote URL of the test image.
    * ``IMG_BASE64_0`` / ``IMG_BASE64_1`` -- the image downloaded and encoded
      as a bare base64 string.
    * ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1`` -- the same base64 payload
      prefixed with ``data:image/png;base64,``.

    Any other value is returned unchanged, which lets tests pass deliberately
    invalid inputs (e.g. ``"malformed"``) straight through.

    Args:
        id: Symbolic identifier, or an arbitrary string to pass through.

    Returns:
        The URL, base64 string, data URI, or the input itself.

    Raises:
        requests.HTTPError: If downloading the image returns an error status.
        requests.Timeout: If the download does not respond in time.
    """
    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

    # Identifiers that map directly to a remote URL (no download needed).
    plain_urls = {
        "IMG_URL_0": IMG_URL_0,
        "IMG_URL_1": IMG_URL_1,
    }
    if id in plain_urls:
        return plain_urls[id]

    # Identifiers that require downloading the image: (source URL, prefix).
    data_uri_prefix = "data:image/png;base64,"
    encoded_sources = {
        "IMG_BASE64_URI_0": (IMG_URL_0, data_uri_prefix),
        "IMG_BASE64_0": (IMG_URL_0, ""),
        "IMG_BASE64_URI_1": (IMG_URL_1, data_uri_prefix),
        "IMG_BASE64_1": (IMG_URL_1, ""),
    }
    if id not in encoded_sources:
        # Unknown identifier: hand it back untouched.
        return id

    source_url, prefix = encoded_sources[id]
    # Without a timeout, requests waits forever on a stalled connection and
    # the whole test run hangs.
    response = requests.get(source_url, timeout=30)
    response.raise_for_status()
    return prefix + base64.b64encode(response.content).decode("utf-8")
|
| |
|
# Keys of the structured prompt objects that the /completions and /embeddings
# tests in this module send: one holds the list of image payloads, the other
# the prompt text.
JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"
|
| |
|
@pytest.fixture(autouse=True)
def create_server():
    """Bind the module-level ``server`` to the tinygemma3 preset.

    The fixture is autouse, so it runs for every test in this module without
    being requested. It only creates the preset; each test starts the server
    itself.
    """
    global server
    # Rebinding the module global is what makes the preset visible to tests.
    server = ServerPreset.tinygemma3()
|
| |
|
def test_models_supports_multimodal_capability():
    """GET /models must list both completion and multimodal capabilities."""
    server.start()
    res = server.make_request("GET", "/models", data={})
    assert res.status_code == 200

    # Only the first listed model is inspected.
    model_info = res.body["models"][0]
    print(model_info)
    for capability in ("completion", "multimodal"):
        assert capability in model_info["capabilities"]
|
| |
|
def test_v1_models_supports_multimodal_capability():
    """GET /v1/models must list both completion and multimodal capabilities."""
    server.start()
    res = server.make_request("GET", "/v1/models", data={})
    assert res.status_code == 200

    # Only the first listed model is inspected.
    model_info = res.body["models"][0]
    print(model_info)
    for capability in ("completion", "multimodal"):
        assert capability in model_info["capabilities"]
|
| |
|
@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # NOTE(review): the expected words do not match the image file names
        # (11_truck.png -> "cat", 91_cat.png -> "frog"); presumably this pins
        # what the tiny test model actually emits -- confirm before changing.
        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
        ("Test test\n", "IMG_URL_1", True, "(frog)+"),
        # Inputs that the server is expected to reject.
        ("What is this:\n", "malformed", False, None),
        ("What is this:\n", "https://google.com/404", False, None),
        ("What is this:\n", "https://ggml.ai", False, None),
    ],
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """POST /chat/completions with one text part and one image_url part.

    Successful cases must return HTTP 200 with an assistant message whose
    content matches ``re_content``; failing cases only need a non-200 status.
    """
    server.start()

    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": get_img_url(image_url)},
    }
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [text_part, image_part]},
        ],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    message = res.body["choices"][0]["message"]
    assert "assistant" == message["role"]
    assert match_regex(re_content, message["content"])
|
| |
|
| |
|
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
        # Inputs that the server is expected to reject.
        ("What is this: <__media__>\n", "malformed", False, None),
        ("What is this:\n", "", False, None),
    ],
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST /completions with a structured prompt carrying one image payload.

    Successful cases must return HTTP 200 with ``content`` matching
    ``re_content``; failing cases only need a non-200 status.
    """
    server.start()

    structured_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "prompt": structured_prompt,
    }
    res = server.make_request("POST", "/completions", data=payload)

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    generated = res.body["content"]
    assert match_regex(re_content, generated)
|
| |
|
| |
|
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
        # Inputs that the server is expected to reject.
        ("What is this: <__media__>\n", "malformed", False),
        ("What is this:\n", "base64", False),
    ],
)
def test_vision_embeddings(prompt, image_data, success):
    """POST /embeddings with two identical image prompts and one text-only prompt.

    On success the two identical inputs must produce equal embeddings, and
    the text-only input must produce a different one. Failing cases only
    need a non-200 status.
    """
    server.server_embeddings = True
    server.n_batch = 512
    server.start()

    image_data = get_img_url(image_data)
    # Two separate but identical prompt+image items, then a text-only item.
    items = [
        {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [image_data]}
        for _ in range(2)
    ]
    items.append({JSON_PROMPT_STRING_KEY: prompt})
    res = server.make_request("POST", "/embeddings", data={"content": items})

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    results = res.body
    # Identical inputs must yield identical embeddings.
    assert results[0]['embedding'] == results[1]['embedding']
    # Dropping the image must change the embedding.
    assert results[0]['embedding'] != results[2]['embedding']
|
| |
|