Spaces:
Sleeping
Sleeping
# gen_ai_base.py
from google import genai
import asyncio
from pydantic import BaseModel
from typing import Optional
class GenAIBaseClient:
    """Thin async wrapper around the google-genai client that tracks token usage."""

    def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
        """Create a genai client for *model_name* and zero the usage counters.

        Args:
            api_key: Google AI Studio / Gemini API key.
            model_name: Gemini model identifier; defaults to "gemini-2.0-flash".
        """
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        # Cumulative token counts across every call made through this instance.
        self.token_usage = {
            "input": 0,
            "output": 0,
            "total": 0
        }

    async def formated_prompt(self, prompt: str, response_schema: Optional[type[BaseModel]] = None) -> dict:
        """Send *prompt* to the Gemini model and return the result plus token usage.

        The genai SDK call is synchronous, so it is executed in a worker thread
        via ``asyncio.to_thread`` to keep the event loop responsive.

        Args:
            prompt: The text prompt to send.
            response_schema: Optional Pydantic model *class*. When provided, the
                model is asked for JSON matching the schema and the parsed model
                is returned as a plain dict; otherwise the raw text is returned.

        Returns:
            dict with keys ``"response"`` (raw SDK response, or None on error),
            ``"parsed"`` (dict / str / None) and ``"usage"`` (per-call token counts).
        """
        # Synchronous helper holding the blocking SDK call; run in a thread below.
        def blocking_call():
            return self.client.models.generate_content(
                model=self.model_name,
                contents=prompt,
                config=genai.types.GenerateContentConfig(
                    response_mime_type="application/json" if response_schema else "text/plain",
                    response_schema=response_schema,
                ),
            )

        try:
            response = await asyncio.to_thread(blocking_call)
            usage_info = self._get_usage_from_response(response)
            if response_schema:
                # response.parsed is a Pydantic instance when the SDK could parse
                # the JSON; None otherwise.
                parsed_data = response.parsed.model_dump() if response.parsed else None
            else:
                parsed_data = response.text
            # BUG FIX: the success path previously omitted the "response" key that
            # the error path returns, giving callers an inconsistent dict shape.
            return {
                "response": response,
                "parsed": parsed_data,
                "usage": usage_info
            }
        except Exception as e:
            # Best-effort contract: report the failure and degrade to an empty
            # result rather than propagating, as callers expect a dict back.
            print(f"Error during GenAI prompt: {e}")
            return {
                "response": None,
                "parsed": None,
                "usage": {"input": 0, "output": 0, "total": 0}
            }

    # Correctly spelled alias; the original (typo'd) name is kept for existing callers.
    formatted_prompt = formated_prompt

    def _get_usage_from_response(self, response) -> dict:
        """Extract this call's token usage from *response* and add it to the running totals.

        Returns:
            dict with ``"input"``, ``"output"`` and ``"total"`` counts for this
            single response (all zero when the response carries no metadata).
        """
        usage = getattr(response, "usage_metadata", None)
        if not usage:
            # No metadata (e.g. failed or mocked response): report zeros and
            # leave the cumulative totals untouched.
            return {"input": 0, "output": 0, "total": 0}
        # The SDK may report None for individual counts; coerce to 0.
        input_tokens = getattr(usage, "prompt_token_count", 0) or 0
        output_tokens = getattr(usage, "candidates_token_count", 0) or 0
        # Fall back to the sum when total_token_count is missing or zero.
        total_tokens = getattr(usage, "total_token_count", 0) or (input_tokens + output_tokens)
        self.token_usage["input"] += input_tokens
        self.token_usage["output"] += output_tokens
        self.token_usage["total"] += total_tokens
        return {
            "input": input_tokens,
            "output": output_tokens,
            "total": total_tokens
        }