# api-web-crawler — app/util/gen_ai_base.py
# Provenance (from Hugging Face page header): author mrfirdauss,
# commit 40c79b0 "init: api-web-crawler to hf"
# gen_ai_base.py
from google import genai
import asyncio
from pydantic import BaseModel
from typing import Optional
class GenAIBaseClient:
    """Thin async wrapper around the Gemini (google-genai) client.

    Keeps a running total of token usage for all prompts sent through this
    instance in ``self.token_usage``.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
        """Create the underlying SDK client and zero the usage counters.

        Args:
            api_key: Gemini API key.
            model_name: Model identifier passed to ``generate_content``.
        """
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        # Cumulative token counts across every call made by this instance.
        self.token_usage = {
            "input": 0,
            "output": 0,
            "total": 0,
        }

    async def formated_prompt(self, prompt: str, response_schema: Optional[type[BaseModel]] = None) -> dict:
        """Send *prompt* to the model and return parsed output plus token usage.

        The blocking SDK call runs in a worker thread via ``asyncio.to_thread``
        so the event loop is never blocked.

        Args:
            prompt: Text prompt to send.
            response_schema: Optional pydantic model *class* (fix: annotated as
                ``type[BaseModel]``, not an instance) used to request and parse
                a JSON response.

        Returns:
            dict with keys ``response`` (raw text or None), ``parsed`` (model
            dump, raw text, or None) and ``usage`` (this call's token counts).
            Fix: the success path now includes ``response`` so both success and
            error paths return the same key set.
        """

        def blocking_call():
            # Synchronous SDK call; request JSON only when a schema is given.
            return self.client.models.generate_content(
                model=self.model_name,
                contents=prompt,
                config=genai.types.GenerateContentConfig(
                    response_mime_type="application/json" if response_schema else "text/plain",
                    response_schema=response_schema,
                ),
            )

        try:
            response = await asyncio.to_thread(blocking_call)
            usage_info = self._get_usage_from_response(response)
            if response_schema:
                # .parsed may be None if the model output failed schema validation.
                parsed_data = response.parsed.model_dump() if response.parsed else None
            else:
                parsed_data = response.text
            return {
                "response": getattr(response, "text", None),
                "parsed": parsed_data,
                "usage": usage_info,
            }
        except Exception as e:
            # Best-effort boundary: log and degrade to an empty result rather
            # than propagate SDK/network failures to the caller.
            print(f"Error during GenAI prompt: {e}")
            return {
                "response": None,
                "parsed": None,
                "usage": {"input": 0, "output": 0, "total": 0},
            }

    def _get_usage_from_response(self, response) -> dict:
        """Extract this call's token counts and add them to the running totals.

        Returns zeros (without touching the totals) when the response carries
        no ``usage_metadata``.
        """
        usage = getattr(response, "usage_metadata", None)
        if not usage:
            return {"input": 0, "output": 0, "total": 0}
        # The SDK may report None for individual counters; coerce to 0.
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0
        # Fall back to the sum when the SDK omits/zeroes the total.
        total_tokens = getattr(usage, "total_token_count", 0) or (input_tokens + output_tokens)
        self.token_usage["input"] += input_tokens
        self.token_usage["output"] += output_tokens
        self.token_usage["total"] += total_tokens
        return {
            "input": input_tokens,
            "output": output_tokens,
            "total": total_tokens,
        }