from loguru import logger
import re
import os
import time
import base64
import json
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor

import requests
from tqdm import tqdm
from google import genai
from openai import OpenAI
from together import Together
import anthropic
from anthropic.types import ThinkingBlock, TextBlock
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request


def encode_image(image_path):
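    """Read an image file and return its contents as a base64 string."""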
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


class APIQuery:
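    """Unified client for several LLM APIs (OpenAI, Anthropic, Google, Together,
    and OpenAI-compatible endpoints) with retries, concurrent execution,
    cost tracking, and optional batch processing.
    """
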
    def __init__(self, model,
                 timeout=6000,
                 max_tokens=None,
                 api='openai',
                 max_retries=50,
                 concurrent_requests=30,
                 is_chat=True,
                 no_system_messages=False,
                 read_cost=1,
                 write_cost=1,
                 sleep_on_error=60,
                 sleep_after_request=0.1,
                 throw_error_on_failure=False,
                 max_tokens_param="max_tokens",
                 reasoning_effort=None,
                 batch_processing=False,
                 openai_responses=False,
                 **kwargs):
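        """Set up the query client.

        read_cost and write_cost are prices per million input and output
        tokens, used only by get_cost. Extra kwargs are forwarded to the
        underlying completion call.
        """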
        if "think" in model and api == "google":
            logger.info("Google 'think' models do not support the chat format.")
            is_chat = False
            max_tokens_param = "max_output_tokens"
        if ("o1" in model or "o3" in model) and api == "openai":
            logger.info("Not using system messages for o1/o3 models.")
            no_system_messages = True
            max_tokens_param = "max_completion_tokens"
        if "--" in model:
            model, reasoning_effort = model.split("--")
            logger.info(f"Model: {model}, Reasoning effort: {reasoning_effort}")
        if api == "anthropic" and "claude-3-7" not in model:
            logger.info("Capping max tokens at 8192 for the Anthropic API.")
            max_tokens = 8192 if max_tokens is None else min(8192, max_tokens)
        if api == "deepseek":
            logger.info("Capping max tokens at 8192 for the DeepSeek API.")
            max_tokens = 8192 if max_tokens is None else min(8192, max_tokens)
        if api not in ["anthropic", "openai"] and batch_processing:
            logger.warning("Batch processing is only supported for the Anthropic and OpenAI APIs.")
            batch_processing = False
        if openai_responses:
            max_tokens_param = "max_output_tokens"

        self.kwarg_remover(api, model, kwargs)

        self.model = model
        self.kwargs = kwargs
        # Only forward max_tokens when set; passing None would break some APIs.
        if max_tokens is not None:
            self.kwargs[max_tokens_param] = max_tokens
        self.max_tokens_param = max_tokens_param
        self.timeout = timeout
        self.max_retries = max_retries
        self.throw_error_on_failure = throw_error_on_failure
        self.concurrent_requests = concurrent_requests
        self.is_chat = is_chat
        self.no_system_messages = no_system_messages
        self.sleep_on_error = sleep_on_error
        self.sleep_after_request = sleep_after_request
        self.read_cost = read_cost
        self.write_cost = write_cost
        self.batch_processing = batch_processing
        self.openai_responses = openai_responses

        if reasoning_effort is not None:
            if not self.openai_responses:
                self.kwargs["reasoning_effort"] = reasoning_effort
            else:
                self.kwargs["reasoning"] = {"effort": reasoning_effort}

        self.api = api
        # Preserve the requested provider name; self.api may be remapped to
        # "openai" by initialize_api_keys for OpenAI-compatible endpoints.
        self.requested_api = api
        self.api_key = None
        self.base_url = None

        self.initialize_api_keys()

    def kwarg_remover(self, api, model, kwargs):
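        """Remove sampling parameters (in place) that reasoning models reject."""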
        if (api == "anthropic" and "claude-3-7" in model) or (("o1" in model or "o3" in model) and api == "openai"):
            for kwarg_to_remove in ["top_p", "top_k", "temperature"]:
                if kwarg_to_remove in kwargs:
                    logger.info(f"Removing {kwarg_to_remove} parameter for {model} model.")
                    del kwargs[kwarg_to_remove]

    def initialize_api_keys(self):
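        """Look up the API key from the environment and set the base URL.

        Providers that expose an OpenAI-compatible endpoint are remapped to
        api="openai" with the provider's base URL.
        """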
        if self.api == "openai":
            self.api_key = os.getenv("OPENAI_API_KEY")
        elif self.api == "together":
            self.api_key = os.getenv("TOGETHER_API_KEY")
            self.base_url = "https://api.together.xyz/v1"
        elif self.api == "google":
            self.api_key = os.getenv("GOOGLE_API_KEY")
            if "think" not in self.model:
                # Non-"think" models go through Google's OpenAI-compatible endpoint.
                self.api = "openai"
                self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
        elif self.api == "anthropic":
            self.api_key = os.getenv("ANTHROPIC_API_KEY")
        elif self.api == "hyperbolic":
            self.api_key = os.getenv("HYPERBOLIC_API_KEY")
            self.base_url = "https://api.hyperbolic.xyz/v1"
            self.api = "openai"
        elif self.api == 'sambanova':
            self.api_key = os.getenv("SAMBA_API_KEY")
            self.base_url = "https://api.sambanova.ai/v1"
            self.api = "openai"
        elif self.api == "deepseek":
            self.api_key = os.getenv("DEEPSEEK_API_KEY")
            self.base_url = "https://api.deepseek.com"
            self.api = "openai"
        elif self.api == "openrouter":
            self.api_key = os.getenv("OPENROUTER_API_KEY")
            self.base_url = "https://openrouter.ai/api/v1"
            self.api = "openai"
        elif self.api == "fireworks":
            self.api_key = os.getenv("FIREWORKS_API_KEY")
            self.base_url = "https://api.fireworks.ai/inference/v1"
            self.api = "openai"
        elif self.api == "vllm":
            # Local vLLM server behind the OpenAI-compatible API; the token is a placeholder.
            self.api_key = "token-abc123"
            self.api = "openai"
            self.base_url = "http://localhost:8000/v1"
        else:
            raise ValueError(f"API {self.api} not supported.")
        assert self.api_key is not None, "API key not found."

    def prepare_query(self, query):
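        """Normalize a (messages, image_path) pair for the target model.

        Non-chat models get the messages flattened into one prompt string;
        if system messages are unsupported, they are downgraded to user role.
        """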
        query, image_path = query
        if not self.is_chat:
            # Flatten the conversation into a single prompt string, starting
            # from the first message's content.
            output_query = query[0]["content"]
            for message in query[1:]:
                output_query += f"\n\n{'=' * 20}{message['role']}{'=' * 20}\n\n{message['content']}"
            return output_query, image_path
        elif self.no_system_messages:
            # Downgrade system messages to user messages.
            query = [{
                "role": message["role"] if message["role"] != "system" else "user",
                "content": message["content"]
            } for message in query]
        return query, image_path

    def get_cost(self, response):
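        """Return the cost of a response (token prices are per million tokens)."""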
        cost = response["input_tokens"] * self.read_cost + response["output_tokens"] * self.write_cost
        return cost / (10 ** 6)

    def run_queries(self, queries):
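        """Run all queries and yield (index, output, detailed_cost) tuples.

        With concurrent execution, results are yielded in completion order,
        so indices may arrive out of order.
        """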
        queries_actual = []
        for query in queries:
            if not isinstance(query, tuple):
                queries_actual.append((query, None))
            else:
                queries_actual.append(query)
        # Check the requested provider, since self.api is remapped to "openai" for vLLM.
        if self.requested_api == "vllm":
            # Poll until the local server responds; a 401 from the root
            # endpoint means the server is up (it just rejects the request).
            while True:
                try:
                    response = requests.get(self.base_url, timeout=1)
                    if response.status_code == 401:
                        break
                except Exception:
                    pass
                time.sleep(5)
                logger.info("Waiting for vLLM server to start...")
            logger.info("vLLM server started.")

        logger.info(f"Running {len(queries_actual)} queries.")

        if self.batch_processing:
            if self.api == "openai":
                processed_results = self.openai_batch_processing(queries_actual)
            else:
                processed_results = self.anthropic_batch_processing(queries_actual)
            for idx, result in enumerate(processed_results):
                detailed_cost = {
                    "cost": self.get_cost(result),
                    "input_tokens": result["input_tokens"],
                    "output_tokens": result["output_tokens"],
                }
                yield idx, result["output"], detailed_cost
        else:
            with ThreadPoolExecutor(max_workers=self.concurrent_requests) as executor:
                future_to_index = {
                    executor.submit(self.run_query_with_retry, query): i
                    for i, query in enumerate(queries_actual)
                }
                for future in tqdm(as_completed(future_to_index), total=len(future_to_index)):
                    idx = future_to_index[future]
                    result = future.result()
                    detailed_cost = {
                        "cost": self.get_cost(result),
                        "input_tokens": result["input_tokens"],
                        "output_tokens": result["output_tokens"],
                    }
                    yield idx, result["output"], detailed_cost

    def run_query_with_retry(self, query):
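        """Run a single query, retrying on failure.

        Rate-limit errors (429) do not count against the retry budget.
        """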
        i = 0
        while i < self.max_retries:
            try:
                output = self.run_query(query)
                time.sleep(self.sleep_after_request)
                return output
            except Exception as e:
                logger.error(f"Error: {e}")
                time.sleep(self.sleep_on_error)
                # Rate-limit errors do not consume the retry budget.
                if "rate limit" not in str(e).lower() and "429" not in str(e):
                    i += 1
        if self.throw_error_on_failure:
            raise ValueError("Max retries reached.")
        else:
            return {
                "output": "",
                "input_tokens": 0,
                "output_tokens": 0,
            }

    def run_query(self, query):
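        """Prepare a single query and dispatch it to the configured backend."""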
        query = self.prepare_query(query)
        if self.api == "openai":
            return self.openai_query(query)
        elif self.api == "together":
            return self.together_query(query)
        elif self.api == "google":
            return self.google_query(query)
        elif self.api == "anthropic":
            return self.anthropic_query(query)
        raise ValueError(f"API {self.api} not supported.")

    def postprocess_anthropic_result(self, result):
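        """Extract output text and token usage from an Anthropic response.

        Thinking blocks are wrapped in <think> tags; only the first text
        block is used.
        """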
        output_text = ""
        for content in result.content:
            if isinstance(content, ThinkingBlock):
                output_text += "<think>\n" + content.thinking + "</think>\n\n"
            elif isinstance(content, TextBlock):
                output_text += content.text
                break
        return {
            "output": output_text,
            "input_tokens": result.usage.input_tokens,
            "output_tokens": result.usage.output_tokens,
        }

    def anthropic_batch_processing(self, queries, error_repetition=0):
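        """Run queries through the Anthropic Message Batches API.

        Failed requests are resubmitted recursively; after max_retries
        rounds, remaining failures come back as empty outputs.
        """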
        if error_repetition >= self.max_retries:
            return [
                {
                    "output": "",
                    "input_tokens": 0,
                    "output_tokens": 0,
                } for _ in range(len(queries))
            ]

        text_queries = [query[0] for query in queries]
        client = anthropic.Anthropic(
            api_key=self.api_key,
            max_retries=0,
        )

        # Named batch_requests to avoid shadowing the requests module.
        batch_requests = []

        for i, text_query in enumerate(text_queries):
            kwargs_here = self.kwargs.copy()
            if text_query[0]["role"] == "system":
                kwargs_here["system"] = text_query[0]["content"]
                text_query = text_query[1:]

            request = Request(
                custom_id=f"apiquery-{i}",
                params=MessageCreateParamsNonStreaming(
                    model=self.model,
                    messages=text_query,
                    **kwargs_here
                )
            )
            batch_requests.append(request)

        message_batch = client.messages.batches.create(requests=batch_requests)

        logger.info(f"Running {len(queries)} queries with batch ID {message_batch.id}")

        current_request_counts = dict(message_batch.request_counts)

        while True:
            try:
                message_batch = client.messages.batches.retrieve(
                    message_batch_id=message_batch.id,
                )
            except Exception as e:
                logger.warning(f"Error connecting to Anthropic: {e}. Retrying in 10s.")
            if any(current_request_counts[key] != dict(message_batch.request_counts)[key] for key in current_request_counts):
                current_request_counts = dict(message_batch.request_counts)
                error_sum = sum(current_request_counts[key] for key in current_request_counts if key != "succeeded")
                logger.info(f"Succeeded Requests Progress: {current_request_counts['succeeded']}/{len(queries)}. Errors: {error_sum}")
            if message_batch.processing_status == "ended":
                break
            time.sleep(10)

        outputs = []
        repeat_indices = []

        while True:
            try:
                results = client.messages.batches.results(
                    message_batch_id=message_batch.id,
                )
                break
            except Exception as e:
                logger.error(f"Error connecting to Anthropic: {e}. Retrying in 10 seconds.")
                time.sleep(10)

        for i, result in enumerate(results):
            if result.result.type == "succeeded":
                outputs.append(self.postprocess_anthropic_result(result.result.message))
            else:
                outputs.append(None)
                repeat_indices.append(i)
                if result.result.type == "errored":
                    logger.error(result.result.error)

        if len(repeat_indices) > 0:
            logger.info(f"Repeating {len(repeat_indices)} queries.")
            repeat_queries = [queries[i] for i in repeat_indices]
            repeat_outputs = self.anthropic_batch_processing(repeat_queries, error_repetition + 1)
            for i, output in zip(repeat_indices, repeat_outputs):
                outputs[i] = output

        return outputs

    def anthropic_query(self, query):
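        """Send a single query to the Anthropic Messages API.

        Note: the image path is currently ignored by this backend.
        """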
        query, image_path = query
        client = anthropic.Anthropic(
            api_key=self.api_key,
            max_retries=0,
            timeout=self.timeout,
        )
        system_message = anthropic.NOT_GIVEN
        if query[0]["role"] == "system":
            system_message = query[0]["content"]
            query = query[1:]
        result = client.messages.create(
            model=self.model,
            messages=query,
            system=system_message,
            **self.kwargs
        )

        return self.postprocess_anthropic_result(result)

    def google_query(self, query):
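        """Send a single query to the Google GenAI API, uploading the image if given."""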
        client = genai.Client(api_key=self.api_key, http_options={'api_version': 'v1alpha'})
        query, image_path = query
        if image_path is not None:
            file = client.files.upload(file=image_path)
            query = [query, file]

        response = client.models.generate_content(
            model=self.model,
            contents=query,
        )
        return {
            "output": "\n\n".join([response.candidates[0].content.parts[i].text
                                   for i in range(len(response.candidates[0].content.parts))]),
            "input_tokens": response.usage_metadata.prompt_token_count,
            "output_tokens": response.usage_metadata.candidates_token_count,
        }

    def together_query(self, query):
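        """Send a single query to the Together chat completions API."""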
        client = Together(api_key=self.api_key)
        query, image_path = query
        response = client.chat.completions.create(
            model=self.model,
            messages=query,
            **self.kwargs
        )
        output = response.choices[0].message.content
        # Guard against a present-but-None reasoning_content attribute.
        if getattr(response.choices[0].message, "reasoning_content", None) is not None:
            output = response.choices[0].message.reasoning_content + "\n\n" + output
        return {
            "output": output,
            "input_tokens": response.usage.prompt_tokens,
            "output_tokens": response.usage.completion_tokens,
        }

    def openai_batch_processing(self, queries, error_repetition=0):
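        """Run queries through the OpenAI Batch API on /v1/chat/completions.

        Failed requests are resubmitted recursively; after max_retries
        rounds, remaining failures come back as empty outputs.
        """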
        if error_repetition >= self.max_retries:
            return [
                {
                    "output": "",
                    "input_tokens": 0,
                    "output_tokens": 0,
                } for _ in range(len(queries))
            ]
        text_queries = [query[0] for query in queries]
        jsonl_queries = []

        for i, query in enumerate(text_queries):
            request = {
                "custom_id": f"apiquery-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": self.model,
                    "messages": query,
                    **self.kwargs
                }
            }
            jsonl_queries.append(request)

        client = OpenAI(api_key=self.api_key, base_url=self.base_url,
                        max_retries=0)

        # Write the requests to a temporary JSONL file and upload it.
        tmp = tempfile.NamedTemporaryFile(suffix=".jsonl", delete=False)
        with open(tmp.name, "wb") as f:
            for query in jsonl_queries:
                f.write(json.dumps(query).encode("utf-8"))
                f.write(b"\n")

        batch_input_file = client.files.create(
            file=open(tmp.name, "rb"),
            purpose="batch"
        )

        batch = client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

        tmp.close()

        logger.info(f"Running {len(queries)} queries with batch ID {batch.id} using file with File ID {batch_input_file.id}.")

        request_counts = dict(batch.request_counts)

        while True:
            try:
                batch = client.batches.retrieve(batch.id)
            except Exception as e:
                logger.warning(f"Error connecting to OpenAI: {e}. Retrying in 10s.")
            if any(request_counts[key] != dict(batch.request_counts)[key] for key in request_counts):
                request_counts = dict(batch.request_counts)
                logger.info(f"Completed Requests Progress: {request_counts['completed']}/{len(queries)}. Errors: {request_counts['failed']}/{len(queries)}")
            if batch.status == "completed":
                break
            time.sleep(10)

        while True:
            try:
                file_response = client.files.content(batch.output_file_id)
                break
            except Exception as e:
                logger.error(f"Error connecting to OpenAI: {e}. Retrying in 10 seconds.")
                time.sleep(10)

        json_response = []
        for line in file_response.iter_lines():
            json_response.append(json.loads(line))

        outputs = [None for _ in range(len(queries))]
        repeat_indices = []

        for result in json_response:
            index = int(result["custom_id"].split("-")[-1])
            if result["response"]["status_code"] != 200:
                repeat_indices.append(index)
                logger.error(f"Error: {result['response']['status_code']}")
            else:
                try:
                    outputs[index] = {
                        "output": result["response"]["body"]["choices"][0]["message"]["content"],
                        "input_tokens": result["response"]["body"]["usage"]["prompt_tokens"],
                        "output_tokens": result["response"]["body"]["usage"]["completion_tokens"],
                    }
                except Exception as e:
                    logger.error(f"Error: {e}")
                    repeat_indices.append(index)
        if len(repeat_indices) > 0:
            logger.info(f"Repeating {len(repeat_indices)} queries.")
            repeat_queries = [queries[i] for i in repeat_indices]
            repeat_outputs = self.openai_batch_processing(repeat_queries, error_repetition + 1)
            for i, output in zip(repeat_indices, repeat_outputs):
                outputs[i] = output

        return outputs

    def openai_query(self, query):
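        """Send a single query via the OpenAI Chat Completions API, or the
        Responses API when openai_responses is set."""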
        client = OpenAI(api_key=self.api_key, base_url=self.base_url,
                        timeout=self.timeout, max_retries=0)
        query, image_path = query
        if image_path is not None:
            base64_image = encode_image(image_path)
            query.append({"role": "user", "content": [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}]})

        if not self.openai_responses:
            response = client.chat.completions.create(
                model=self.model,
                messages=query,
                timeout=self.timeout,
                **self.kwargs
            )
            output = response.choices[0].message.content
            if hasattr(response.choices[0].message, "reasoning_content") and \
                    response.choices[0].message.reasoning_content is not None:
                output = response.choices[0].message.reasoning_content + "\n\n" + output
            input_tokens = response.usage.prompt_tokens
            output_tokens = response.usage.completion_tokens
        else:
            response = client.responses.create(
                model=self.model,
                input=query,
                timeout=self.timeout,
                **self.kwargs
            )
            output = response.output[-1].content[0].text
            input_tokens = response.usage.input_tokens
            output_tokens = response.usage.output_tokens

        return {
            "output": output,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }
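

# Minimal usage sketch (assumptions: OPENAI_API_KEY is set in the environment;
# the model name and the per-million-token prices below are illustrative, not
# part of this module).
if __name__ == "__main__":
    example_queries = [
        [{"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "What is 2 + 2?"}],
    ]
    querier = APIQuery(model="gpt-4o-mini", api="openai", max_tokens=256,
                       read_cost=0.15, write_cost=0.6)  # illustrative prices
    for idx, output, detailed_cost in querier.run_queries(example_queries):
        print(idx, output, detailed_cost)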