|
|
import base64
import functools

import pytest
import requests

from utils import *
|
|
|
|
|
# Module-level handle to the server under test. The autouse `create_server`
# fixture in this module rebinds it before each test; tests then configure
# and start it themselves.
server: ServerProcess
|
|
|
|
|
# Test images published next to the tinygemma3 test model on Hugging Face.
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

# Upper bound (seconds) for downloading a test image. Without a timeout a
# stalled connection would hang the whole test run instead of failing it.
IMG_DOWNLOAD_TIMEOUT_S = 30

# Symbolic ids that resolve to a plain remote URL.
_IMG_URL_IDS = {
    "IMG_URL_0": IMG_URL_0,
    "IMG_URL_1": IMG_URL_1,
}

# Symbolic ids that resolve to the base64-encoded image bytes:
# id -> (source URL, whether to wrap the payload in a PNG data: URI).
_IMG_BASE64_IDS = {
    "IMG_BASE64_URI_0": (IMG_URL_0, True),
    "IMG_BASE64_0": (IMG_URL_0, False),
    "IMG_BASE64_URI_1": (IMG_URL_1, True),
    "IMG_BASE64_1": (IMG_URL_1, False),
}


@functools.lru_cache(maxsize=None)
def _download_image(url: str) -> bytes:
    """Fetch *url* and return the raw response body.

    Cached so each test image is downloaded at most once per session. A failed
    download raises, and exceptions are not cached, so a later call retries.
    """
    response = requests.get(url, timeout=IMG_DOWNLOAD_TIMEOUT_S)
    response.raise_for_status()
    return response.content


def get_img_url(id: str) -> str:
    """Resolve a symbolic test-image id to the value sent to the server.

    Args:
        id: one of
            - "IMG_URL_0" / "IMG_URL_1": the remote PNG URL;
            - "IMG_BASE64_0" / "IMG_BASE64_1": bare base64 of that PNG;
            - "IMG_BASE64_URI_0" / "IMG_BASE64_URI_1": a
              ``data:image/png;base64,...`` URI of that PNG.
            Any other string is returned unchanged. This lets test cases pass
            literal (possibly invalid) URLs or data straight through.

    Returns:
        The URL, base64 string, data URI, or *id* itself.

    Raises:
        requests.HTTPError: the download of a base64 source image returned an
            error status.
        requests.Timeout: the download exceeded IMG_DOWNLOAD_TIMEOUT_S.
    """
    if id in _IMG_URL_IDS:
        return _IMG_URL_IDS[id]
    if id not in _IMG_BASE64_IDS:
        return id
    url, as_data_uri = _IMG_BASE64_IDS[id]
    encoded = base64.b64encode(_download_image(url)).decode("utf-8")
    return f"data:image/png;base64,{encoded}" if as_data_uri else encoded
|
|
|
|
|
# Keys of the structured prompt object used by /completions and /embeddings:
# the text prompt and the list of media payloads it refers to (base64 image
# data or pass-through strings in these tests).
JSON_MULTIMODAL_KEY, JSON_PROMPT_STRING_KEY = "multimodal_data", "prompt_string"
|
|
|
|
|
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a new tinygemma3 preset for every test.

    The fixture does not start the server; each test adjusts settings as
    needed and calls ``server.start()`` itself.
    """
    global server
    fresh_server = ServerPreset.tinygemma3()
    server = fresh_server
|
|
|
|
|
def _assert_model_has_multimodal_capability(path: str) -> None:
    """Start the server, GET *path*, and check the first listed model's capabilities.

    The response body is expected to contain a "models" list whose entries
    carry a "capabilities" list. The first model must advertise both
    "completion" and "multimodal".
    """
    server.start()
    res = server.make_request("GET", path, data={})
    assert res.status_code == 200, res.body
    models = res.body["models"]
    assert models, f"{path} returned an empty model list"
    model_info = models[0]
    capabilities = model_info["capabilities"]
    # Attach model_info to the failure message instead of printing it.
    assert "completion" in capabilities, model_info
    assert "multimodal" in capabilities, model_info


def test_models_supports_multimodal_capability():
    _assert_model_has_multimodal_capability("/models")


def test_v1_models_supports_multimodal_capability():
    _assert_model_has_multimodal_capability("/v1/models")
|
|
|
|
|
@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # NOTE(review): the expected labels do not match the image file names
        # (11_truck.png -> "cat", 91_cat.png -> "frog"). Presumably this is what
        # the tiny test model actually answers; confirm before "fixing" it.
        ("What is this:\n", "IMG_URL_0",              True,  "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0",       True,  "(cat)+"),
        ("What is this:\n", "IMG_URL_1",              True,  "(frog)+"),
        # Same image again with a different prompt.
        ("Test test\n",     "IMG_URL_1",              True,  "(frog)+"),
        # Inputs the server must reject: not a URL, a 404, a non-image page.
        ("What is this:\n", "malformed",              False, None),
        ("What is this:\n", "https://google.com/404", False, None),
        ("What is this:\n", "https://ggml.ai",        False, None),
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """Send one text+image user message to /chat/completions and check the reply.

    Successful cases must return 200 with an assistant message whose content
    matches ``re_content``. Failure cases only require a non-200 status.
    """
    server.start()
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": get_img_url(image_url)}},
    ]
    payload = {
        # temperature 0 + top_k 1 so the answer is as deterministic as possible
        "temperature": 0.0,
        "top_k": 1,
        "messages": [{"role": "user", "content": user_content}],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    message = res.body["choices"][0]["message"]
    assert "assistant" == message["role"]
    assert match_regex(re_content, message["content"])
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # NOTE(review): as in the chat test, the labels do not match the file
        # names (truck -> "cat", cat -> "frog"); presumably the model's real output.
        ("What is this: <__media__>\n", "IMG_BASE64_0", True,  "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True,  "(frog)+"),
        # Not valid image data.
        ("What is this: <__media__>\n", "malformed",    False, None),
        # No "<__media__>" marker in the prompt but one (empty) media entry;
        # presumably rejected as a marker/media mismatch - confirm.
        ("What is this:\n",             "",             False, None),
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a structured multimodal prompt to /completions and check the result.

    The prompt object carries the text under JSON_PROMPT_STRING_KEY and a
    single-element media list under JSON_MULTIMODAL_KEY. Successful cases must
    return 200 with "content" matching ``re_content``; failure cases only
    require a non-200 status.
    """
    server.start()
    multimodal_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    res = server.make_request(
        "POST",
        "/completions",
        data={"temperature": 0.0, "top_k": 1, "prompt": multimodal_prompt},
    )

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
        # Not valid image data.
        ("What is this: <__media__>\n", "malformed",    False),
        # Prompt without a "<__media__>" marker plus a (non-image) media entry;
        # presumably rejected as a marker/media mismatch - confirm.
        ("What is this:\n",             "base64",       False),
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """POST three inputs to /embeddings and compare the returned vectors.

    The first two entries are the same prompt with the same media; the third is
    the prompt alone. On success the first two embeddings must be identical and
    the text-only one must differ. Otherwise the request must not return 200.
    """
    # Configure before start(). NOTE(review): n_batch=512 presumably makes the
    # prompt plus image tokens fit in a single batch - confirm.
    server.server_embeddings = True
    server.n_batch = 512
    server.start()

    media = get_img_url(image_data)
    entries = [
        {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [media]}
        for _ in range(2)
    ]
    entries.append({JSON_PROMPT_STRING_KEY: prompt})
    res = server.make_request("POST", "/embeddings", data={"content": entries})

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    results = res.body
    # Same prompt + same media -> identical embedding.
    assert results[0]["embedding"] == results[1]["embedding"]
    # Dropping the media must change the embedding.
    assert results[0]["embedding"] != results[2]["embedding"]
|
|
|