File size: 6,023 Bytes
61ba51e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | import json
import os
import tempfile
import unittest
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
# Register this test file with the CUDA and AMD CI suites named below.
# est_time is presumably the expected runtime in seconds -- TODO confirm
# against sglang.test.ci.ci_register.
# NOTE(review): keep these as plain calls with literal arguments; CI tooling
# may read them statically -- confirm before refactoring.
register_cuda_ci(est_time=38, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=38, suite="stage-b-test-small-1-gpu-amd")
class TestInputEmbeds(CustomTestCase):
    """Exercise generation from precomputed input embeddings.

    One server is launched for the whole class. Embeddings are computed
    locally with the Hugging Face reference model's input-embedding layer.
    They are sent to the server's ``/generate`` endpoint (as
    ``input_embeds``) and ``/generate_from_file`` endpoint (as an uploaded
    JSON file). Plain-text requests are sent alongside them for comparison.
    """

    # Sampling parameters shared by every /generate request in this class.
    SAMPLING_PARAMS = {"temperature": 0, "max_new_tokens": 50}

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer and reference model, then launch the server."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Pass every CLI argument as a string, as on a real command line.
            other_args=["--disable-radix", "--cuda-graph-max-bs", "4"],
        )
        cls.texts = [
            "The capital of France is",
            "What is the best time of year to visit Japan for cherry blossoms?",
        ]

    def generate_input_embeddings(self, text):
        """Return the reference model's input embeddings for ``text``.

        The result is a nested Python list with one row per token, so it can
        be JSON-serialized for the API.
        """
        # return_tensors="pt" for a single string yields input_ids of shape
        # (1, num_tokens); the embedding lookup adds a hidden dimension.
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        embeddings = self.ref_model.get_input_embeddings()(input_ids)
        # Drop only the batch dimension. A bare squeeze() would also drop the
        # token dimension for a single-token prompt and return a flat list.
        # detach() discards the autograd graph built by the embedding lookup.
        return embeddings[0].detach().tolist()

    def _build_payload(self, **inputs):
        """Build a /generate payload from ``text=...`` or ``input_embeds=...``."""
        return {
            "model": self.model,
            **inputs,
            # Copy so no two payloads share one mutable dict.
            "sampling_params": dict(self.SAMPLING_PARAMS),
        }

    @staticmethod
    def _parse_response(response):
        """Return the JSON body on HTTP 200, otherwise an ``{"error": ...}`` dict."""
        if response.status_code == 200:
            return response.json()
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }

    def _assert_success(self, response):
        """Fail the test if a request helper returned its error dict.

        The request helpers report non-200 statuses as ``{"error": ...}``
        instead of raising. Without this check, a failing server would
        still let every test pass.
        """
        if isinstance(response, dict) and "error" in response:
            self.fail(response["error"])

    def send_request(self, payload):
        """Send a POST request to the /generate endpoint and return the response."""
        response = requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=30,  # Set a reasonable timeout for the API request
        )
        return self._parse_response(response)

    def send_file_request(self, file_path):
        """Send a POST request to the /generate_from_file endpoint with a file."""
        with open(file_path, "rb") as f:
            response = requests.post(
                self.base_url + "/generate_from_file",
                files={"file": f},
                timeout=30,  # Set a reasonable timeout for the API request
            )
        return self._parse_response(response)

    def test_text_based_response(self):
        """Send each prompt as plain text and check that the request succeeds."""
        for text in self.texts:
            response = self.send_request(self._build_payload(text=text))
            print(
                f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )
            self._assert_success(response)

    def test_embedding_based_response(self):
        """Send each prompt as input embeddings and check that the request succeeds."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            response = self.send_request(self._build_payload(input_embeds=embeddings))
            print(
                f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )
            self._assert_success(response)

    def test_compare_text_vs_embedding(self):
        """Send each prompt both as text and as embeddings, and print both outputs."""
        for text in self.texts:
            text_payload = self._build_payload(text=text)
            embeddings = self.generate_input_embeddings(text)
            embed_payload = self._build_payload(input_embeds=embeddings)

            text_response = self.send_request(text_payload)
            embed_response = self.send_request(embed_payload)

            print(
                f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
            )
            print(
                f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
            )
            # Both requests must at least succeed.
            self._assert_success(text_response)
            self._assert_success(embed_response)
            # TODO: comparing the generated text is flaky, so it is skipped
            # for now:
            # self.assertEqual(text_response["text"], embed_response["text"])

    def test_generate_from_file(self):
        """Upload each prompt's embeddings as JSON to /generate_from_file."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            # Create the file first so its path is known before anything can
            # fail; the finally clause then always removes it.
            fd, tmp_file_path = tempfile.mkstemp(suffix=".json")
            try:
                with os.fdopen(fd, "w") as tmp_file:
                    json.dump(embeddings, tmp_file)
                response = self.send_file_request(tmp_file_path)
                print(
                    f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
                self._assert_success(response)
            finally:
                os.remove(tmp_file_path)

    @classmethod
    def tearDownClass(cls):
        """Shut down the server process launched in setUpClass."""
        kill_process_tree(cls.process.pid)
if __name__ == "__main__":
    # Direct execution (`python <this file>`) hands control to unittest.
    unittest.main()
|