| import os |
| from abc import ABC, abstractmethod |
| from google import genai |
| from google.genai import types |
| from pydantic import BaseModel |
# NOTE(review): removed stray debug print("dfdf") that ran on every import.
| class LLMClient(ABC): |
| """ |
| Abstract base class for calling LLM APIs. |
| """ |
| def __init__(self, config: dict = None): |
| """ |
| Initializes the LLMClient with a configuration dictionary. |
| |
| Args: |
| config (dict): Configuration settings for the LLM client. |
| """ |
| self.config = config or {} |
|
|
| @abstractmethod |
| def call_api(self, prompt: str) -> str: |
| """ |
| Call the underlying LLM API with the given prompt. |
| |
| Args: |
| prompt (str): The prompt or input text for the LLM. |
| |
| Returns: |
| str: The response from the LLM. |
| """ |
| pass |
|
|
|
|
class GeminiLLMClient(LLMClient):
    """
    Concrete implementation of LLMClient for the Google Gemini API.
    """

    def __init__(self, config: dict):
        """
        Initializes the GeminiLLMClient with an API key, model name, and
        optional generation settings.

        Args:
            config (dict): Configuration containing:
                - 'api_key': (optional) API key for Gemini (falls back to the
                  GEMINI_API_KEY environment variable)
                - 'model_name': (optional) the model to use
                  (default 'gemini-2.0-flash')
                - 'generation_config': (optional) dict of GenerateContentConfig
                  parameters

        Raises:
            ValueError: If no API key is found in the config or the
                GEMINI_API_KEY environment variable.
        """
        # Populate self.config so the LLMClient base-class contract holds
        # (the original override skipped this, leaving self.config unset).
        super().__init__(config)

        api_key = config.get("api_key") or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for Gemini must be provided in config['api_key'] or GEMINI_API_KEY env var."
            )
        self.client = genai.Client(api_key=api_key)
        self.model_name = config.get("model_name", "gemini-2.0-flash")

        gen_conf = config.get("generation_config", {})
        # Unspecified keys resolve to None, letting the API apply its own defaults.
        self.generate_config = types.GenerateContentConfig(
            response_mime_type=gen_conf.get("response_mime_type", "text/plain"),
            temperature=gen_conf.get("temperature"),
            max_output_tokens=gen_conf.get("max_output_tokens"),
            top_p=gen_conf.get("top_p"),
            top_k=gen_conf.get("top_k"),
        )

    def call_api(self, prompt: str) -> str:
        """
        Call the Gemini API with the given prompt (non-streaming).

        Args:
            prompt (str): The input text for the API.

        Returns:
            str: The generated text from the Gemini API.
        """
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            )
        ]
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=self.generate_config,
        )
        return response.text
|
|
| |
|
|
class AIExtractor:
    def __init__(self, llm_client: LLMClient, prompt_template: str):
        """
        Create an extractor that delegates text generation to an LLM client.

        Args:
            llm_client (LLMClient): An instance of a class implementing the
                LLMClient interface.
            prompt_template (str): Prompt template containing ``{content}``
                and ``{schema}`` placeholders, e.g.
                "Extract the following information: {content} based on schema: {schema}"
        """
        self.llm_client = llm_client
        self.prompt_template = prompt_template

    def extract(self, content: str, schema: BaseModel) -> str:
        """
        Extract structured information from ``content`` according to ``schema``.

        Args:
            content (str): The raw content to extract information from.
            schema (BaseModel): Pydantic model describing the expected output
                structure (its JSON schema is interpolated into the prompt).

        Returns:
            str: The structured JSON object as a string.
        """
        filled_prompt = self.prompt_template.format(
            content=content,
            schema=schema.model_json_schema(),
        )
        return self.llm_client.call_api(filled_prompt)
|
|
| |
|
|