Armaan11's picture
Upload folder using huggingface_hub
0828e8a verified
from loguru import logger
import re
import os
from tqdm import tqdm
from google import genai
from openai import OpenAI
from together import Together
import anthropic
from anthropic.types import ThinkingBlock, TextBlock
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
import time
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
import base64
import requests
import json
# Import tempfile to create temporary files
import tempfile
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
class APIQuery:
def __init__(self, model,
timeout=6000,
max_tokens=None,
api='openai',
max_retries=50,
concurrent_requests=30,
is_chat=True,
no_system_messages=False,
read_cost=1,
write_cost=1,
sleep_on_error=60,
sleep_after_request=0.1,
throw_error_on_failure=False,
max_tokens_param="max_tokens",
reasoning_effort=None,
batch_processing=False,
openai_responses=False,
**kwargs):
if "think" in model and api == "google":
logger.info("Google Think model does not allow chat.")
is_chat = False # think model cannot handle chat
max_tokens_param = "max_output_tokens"
if ("o1" in model or "o3" in model) and api == "openai":
logger.info("Not using system messages for o1/o3 model.")
no_system_messages = True # o1 model cannot handle system messages
max_tokens_param = "max_completion_tokens"
if "--" in model:
model, reasoning_effort = model.split("--")
logger.info(f"Model: {model}, Reasoning effort: {reasoning_effort}")
if api == "anthropic" and "claude-3-7" not in model:
logger.info("Setting max tokens to 8192 for Anthropic API.")
max_tokens = min(8192, max_tokens)
if api == "deepseek":
logger.info("Setting max tokens to 8192 for DeepSeek API.")
max_tokens = min(8192, max_tokens)
if api not in ["anthropic", "openai"] and batch_processing:
logger.warning("Batch processing is only supported for the Anthropic API and OpenAI API.")
batch_processing = False
if openai_responses:
max_tokens_param = "max_output_tokens"
self.kwarg_remover(api, model, kwargs)
self.model = model
self.kwargs = kwargs
self.kwargs[max_tokens_param] = max_tokens
self.timeout = timeout
self.max_retries = max_retries
self.throw_error_on_failure = throw_error_on_failure
self.concurrent_requests = concurrent_requests
self.is_chat = is_chat
self.no_system_messages = no_system_messages
self.sleep_on_error = sleep_on_error
self.sleep_after_request = sleep_after_request
self.read_cost = read_cost
self.write_cost = write_cost
self.batch_processing = batch_processing
self.openai_responses = openai_responses
if max_tokens is not None:
self.max_tokens_param = max_tokens_param
if reasoning_effort is not None:
if not self.openai_responses:
self.kwargs["reasoning_effort"] = reasoning_effort
else:
self.kwargs["reasoning"] = {"effort": reasoning_effort}
self.api = api
self.api_key = None
self.base_url = None
self.initialize_api_keys()
def kwarg_remover(self, api, model, kwargs):
if (api == "anthropic" and "claude-3-7" in model) or (("o1" in model or "o3" in model) and api == "openai"):
for kwarg_to_remove in ["top_p", "top_k", "temperature"]:
if kwarg_to_remove in kwargs:
logger.info(f"Removing {kwarg_to_remove} parameter for {model} model.")
del kwargs[kwarg_to_remove]
def initialize_api_keys(self):
if self.api == "openai":
self.api_key = os.getenv("OPENAI_API_KEY")
elif self.api == "together":
self.api_key = os.getenv("TOGETHER_API_KEY")
self.base_url = "https://api.together.xyz/v1"
elif self.api == "google":
self.api_key = os.getenv("GOOGLE_API_KEY")
if not "think" in self.model:
self.api = "openai"
self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
elif self.api == "anthropic":
self.api_key = os.getenv("ANTHROPIC_API_KEY")
elif self.api == "hyperbolic":
self.api_key = os.getenv("HYPERBOLIC_API_KEY")
self.base_url = "https://api.hyperbolic.xyz/v1"
self.api = "openai"
elif self.api == 'sambanova':
self.api_key = os.getenv("SAMBA_API_KEY")
self.base_url = "https://api.sambanova.ai/v1"
self.api = "openai"
elif self.api == "deepseek":
self.api_key = os.getenv("DEEPSEEK_API_KEY")
self.base_url = "https://api.deepseek.com"
self.api = "openai"
elif self.api == "openrouter":
self.api_key = os.getenv("OPENROUTER_API_KEY")
self.base_url = "https://openrouter.ai/api/v1"
self.api = "openai"
elif self.api == "fireworks":
self.api_key = os.getenv("FIREWORKS_API_KEY")
self.base_url = "https://api.fireworks.ai/inference/v1"
self.api = "openai"
elif self.api == "vllm":
self.api_key = "token-abc123"
self.api = "openai"
self.base_url = f"http://localhost:8000/v1"
# command = f"vllm serve {self.model} --dtype auto --api-key token-abc123"
# Launch the command in the background.
# subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Poll the server until it's running.
else:
raise ValueError(f"API {self.api} not supported.")
assert self.api_key is not None, f"API key not found."
def prepare_query(self, query):
query, image_path = query
if not self.is_chat:
output_query = query[0]["content"]
for message in query:
output_query += f"\n\n{'=' * 20}{message['role']}{'=' * 20}\n\n{message['content']}"
return output_query, image_path
elif self.no_system_messages:
# convert system role to user role
query = [{
"role": message["role"] if message["role"] != "system" else "user",
"content": message["content"]
} for message in query]
return query, image_path
def get_cost(self, response):
cost = response["input_tokens"] * self.read_cost + response["output_tokens"] * self.write_cost
return cost / (10 ** 6)
def run_queries(self, queries):
queries_actual = []
for query in queries:
if not isinstance(query, tuple):
queries_actual.append((query, None))
else:
queries_actual.append(query)
if self.api == "vllm":
while True:
try:
response = requests.get(f"{self.base_url}", timeout=1)
if response.status_code == 401: # unauthorized, because no api key here
break
except Exception:
pass
time.sleep(5)
logger.info("Waiting for VLLM server to start...")
logger.info("VLLM server started.")
logger.info(f"Running {len(queries_actual)} queries.")
if self.batch_processing:
if self.api == "openai":
processed_results = self.openai_batch_processing(queries_actual)
else:
processed_results = self.anthropic_batch_processing(queries_actual)
for idx, result in enumerate(processed_results):
detailed_cost = {
"cost": self.get_cost(result),
"input_tokens": result["input_tokens"],
"output_tokens": result["output_tokens"],
}
yield idx, result["output"], detailed_cost
else:
with ThreadPoolExecutor(max_workers=self.concurrent_requests) as executor:
future_to_index = {
executor.submit(self.run_query_with_retry, query): i
for i, query in enumerate(queries_actual)
}
for future in tqdm(as_completed(future_to_index), total=len(future_to_index)):
idx = future_to_index[future]
result = future.result()
detailed_cost = {
"cost": self.get_cost(result),
"input_tokens": result["input_tokens"],
"output_tokens": result["output_tokens"],
}
yield idx, result["output"], detailed_cost
def run_query_with_retry(self, query):
i = 0
while i < self.max_retries:
try:
output = self.run_query(query)
time.sleep(self.sleep_after_request)
return output
except Exception as e:
logger.error(f"Error: {e}")
time.sleep(self.sleep_on_error)
# if api error is not due to rate limit, try again
if "rate limit" not in str(e).lower() and "429" not in str(e):
i += 1
continue
if self.throw_error_on_failure:
raise ValueError("Max retries reached.")
else:
return {
"output": "",
"input_tokens": 0,
"output_tokens": 0,
}
def run_query(self, query):
query = self.prepare_query(query)
if self.api == "openai":
return self.openai_query(query)
elif self.api == "together":
return self.together_query(query)
elif self.api == "google":
return self.google_query(query)
elif self.api == "anthropic":
return self.anthropic_query(query)
def postprocess_anthropic_result(self, result):
output_text = ""
for content in result.content:
if isinstance(content, ThinkingBlock):
output_text += "<think>\n" + content.thinking + "</think>\n\n"
elif isinstance(content, TextBlock):
output_text += content.text
break
return {
"output": output_text,
"input_tokens": result.usage.input_tokens,
"output_tokens": result.usage.output_tokens,
}
def anthropic_batch_processing(self, queries, error_repetition=0):
if error_repetition >= self.max_retries:
return [
{
"output": "",
"input_tokens": 0,
"output_tokens": 0,
} for _ in range(len(queries))
]
text_queries = [query[0] for query in queries]
client = anthropic.Anthropic(
api_key=self.api_key,
max_retries=0,
)
requests = []
for i, text_query in enumerate(text_queries):
kwargs_here = self.kwargs.copy()
if text_query[0]["role"] == "system":
kwargs_here["system"] = text_query[0]["content"]
text_query = text_query[1:]
request = Request(
custom_id=f"apiquery-{i}",
params=MessageCreateParamsNonStreaming(
model=self.model,
messages=text_query,
**kwargs_here
)
)
requests.append(request)
message_batch = client.messages.batches.create(requests=requests)
logger.info(f"Running {len(queries)} queries with batch ID {message_batch.id}")
current_request_counts = dict(message_batch.request_counts)
while True:
try:
message_batch = client.messages.batches.retrieve(
message_batch_id=message_batch.id,
)
except:
logger.warning(f"Error connecting to Anthropic. Retrying in 10s.")
pass
if any([current_request_counts[key] != dict(message_batch.request_counts)[key] for key in current_request_counts]):
current_request_counts = dict(message_batch.request_counts)
error_sum = sum([current_request_counts[key] for key in current_request_counts if "succeeded" != key])
logger.info(f"Succeeded Requests Progress: {current_request_counts['succeeded']}/{len(queries)}. Errors: {error_sum}")
if message_batch.processing_status == "ended":
break
time.sleep(10)
outputs = []
repeat_indices = []
while True:
try:
results = client.messages.batches.results(
message_batch_id=message_batch.id,
)
break
except Exception as e:
logger.error(f"Error connecting to Anthropic: {e}. Retrying in 10 seconds.")
time.sleep(10)
for i, result in enumerate(results):
if result.result.type == "succeeded":
outputs.append(self.postprocess_anthropic_result(result.result.message))
else:
outputs.append(None)
repeat_indices.append(i)
if result.result.type == "errored":
logger.error(result.result.error)
if len(repeat_indices) > 0:
logger.info(f"Repeating {len(repeat_indices)} queries.")
repeat_queries = [queries[i] for i in repeat_indices]
repeat_outputs = self.anthropic_batch_processing(repeat_queries, error_repetition + 1)
for i, output in zip(repeat_indices, repeat_outputs):
outputs[i] = output
return outputs
def anthropic_query(self, query):
query, image_path = query
client = anthropic.Anthropic(
api_key=self.api_key,
max_retries=0,
timeout=self.timeout,
)
system_message = anthropic.NOT_GIVEN
if query[0]["role"] == "system":
system_message = query[0]["content"]
query = query[1:]
result = client.messages.create(
model=self.model,
messages=query,
system=system_message,
**self.kwargs
)
return self.postprocess_anthropic_result(result)
def google_query(self, query):
client = genai.Client(api_key=self.api_key, http_options={'api_version':'v1alpha'})
query, image_path = query
if image_path is not None:
file = client.files.upload(file=image_path)
query = [query, file]
# if "think" in self.model:
# config['thinking_config'] = {'include_thoughts': True}
response = client.models.generate_content(
model=self.model,
contents=query,
# config=config,
)
return {
"output": "\n\n".join([response.candidates[0].content.parts[i].text
for i in range(len(response.candidates[0].content.parts))]),
"input_tokens": response.usage_metadata.prompt_token_count,
"output_tokens": response.usage_metadata.candidates_token_count,
}
def together_query(self, query):
client = Together()
query, image_path = query
response = client.chat.completions.create(
model=self.model,
messages=query,
**self.kwargs
)
output = response.choices[0].message.content
if hasattr(response.choices[0].message, "reasoning_content"):
output = response.choices[0].message.reasoning_content + "\n\n" + output
return {
"output": output,
"input_tokens": response.usage.prompt_tokens,
"output_tokens": response.usage.completion_tokens,
}
def openai_batch_processing(self, queries, error_repetition=0):
if error_repetition >= self.max_retries:
return [
{
"output": "",
"input_tokens": 0,
"output_tokens": 0,
} for _ in range(len(queries))
]
text_queries = [query[0] for query in queries]
jsonl_queries = []
for i, query in enumerate(text_queries):
request = {
"custom_id": f"apiquery-{i}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": self.model,
"messages": query,
**self.kwargs
}
}
jsonl_queries.append(request)
client = OpenAI(api_key=self.api_key, base_url=self.base_url,
max_retries=0)
# create temp file
tmp = tempfile.NamedTemporaryFile(suffix=".jsonl", delete=False)
with open(tmp.name, "wb") as f:
for i, query in enumerate(jsonl_queries):
f.write(json.dumps(query).encode("utf-8"))
f.write(b"\n")
batch_input_file = client.files.create(
file=open(tmp.name, "rb"),
purpose="batch"
)
batch = client.batches.create(
input_file_id=batch_input_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
# close tmp file
tmp.close()
logger.info(f"Running {len(queries)} queries with batch ID {batch.id} using file with File ID {batch_input_file.id}.")
request_counts = dict(batch.request_counts)
while True:
try:
batch = client.batches.retrieve(batch.id)
except Exception as e:
logger.warning(f"Error connecting to OpenAI. Retrying in 10s.")
pass
if any([request_counts[key] != dict(batch.request_counts)[key] for key in request_counts]):
request_counts = dict(batch.request_counts)
logger.info(f"Completed Requests Progress: {request_counts['completed']}/{len(queries)}. Errors: {request_counts['failed']}/{len(queries)}")
if batch.status == "completed":
break
time.sleep(10)
while True:
try:
file_response = client.files.content(batch.output_file_id)
break
except Exception as e:
logger.error(f"Error connecting to OpenAI: {e}. Retrying in 10 seconds.")
time.sleep(10)
continue
json_response = []
for line in file_response.iter_lines():
json_response.append(json.loads(line))
outputs = [None for _ in range(len(queries))]
repeat_indices = []
for result in json_response:
index = int(result["custom_id"].split("-")[-1])
if result["response"]["status_code"] != 200:
repeat_indices.append(index)
logger.error(f"Error: {result['response']['status_code']}")
else:
try:
outputs[index] = {
"output": result["response"]["body"]["choices"][0]["message"]["content"],
"input_tokens": result["response"]["body"]["usage"]["prompt_tokens"],
"output_tokens": result["response"]["body"]["usage"]["completion_tokens"],
}
except Exception as e:
logger.error(f"Error: {e}")
repeat_indices.append(index)
if len(repeat_indices) > 0:
logger.info(f"Repeating {len(repeat_indices)} queries.")
repeat_queries = [queries[i] for i in repeat_indices]
repeat_outputs = self.openai_batch_processing(repeat_queries, error_repetition + 1)
for i, output in zip(repeat_indices, repeat_outputs):
outputs[i] = output
return outputs
def openai_query(self, query):
client = OpenAI(api_key=self.api_key, base_url=self.base_url,
timeout=self.timeout, max_retries=0)
query, image_path = query
if image_path is not None:
base64_image = encode_image(image_path)
query.append({"role": "user", "content": [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}]})
if not self.openai_responses:
response = client.chat.completions.create(
model=self.model,
messages=query,
timeout=self.timeout,
**self.kwargs
)
output = response.choices[0].message.content
if hasattr(response.choices[0].message, "reasoning_content") and \
response.choices[0].message.reasoning_content is not None:
output = response.choices[0].message.reasoning_content + "\n\n" + output
input_tokens = response.usage.prompt_tokens
output_tokens = response.usage.completion_tokens
else:
response = client.responses.create(
model=self.model,
input=query,
timeout=self.timeout,
**self.kwargs
)
output = response.output[-1].content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
return {
"output": output,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}