File size: 3,940 Bytes
676582c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Gemini AI provider implementation."""
import google.generativeai as genai
from typing import List, Dict, Any
from .base import LLMProvider


class GeminiProvider(LLMProvider):
    """Google Gemini AI provider implementation.

    Uses the google-generativeai library to interact with Gemini models.
    Supports gemini-pro and other Gemini model variants.
    """

    def __init__(self, api_key: str, model_name: str = "google/gemini-2.0-flash-exp:free"):
        """Initialize the Gemini provider.

        Args:
            api_key: Google AI API key
            model_name: Gemini model name
                (default: "google/gemini-2.0-flash-exp:free")

        NOTE(review): the default model id is in OpenRouter's
        "vendor/model:tier" format, but the google-generativeai SDK expects
        bare ids such as "gemini-2.0-flash-exp" — confirm the intended
        default against the deployment before relying on it.
        """
        super().__init__(api_key, model_name)
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)

    @staticmethod
    def _build_history(
        messages: List[Dict[str, str]],
        system_prompt: str | None
    ) -> List[Dict[str, Any]]:
        """Convert role/content message dicts into Gemini chat-history entries.

        Gemini has no dedicated system role, so a system prompt is emulated
        as a leading user turn followed by a canned model acknowledgement.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            system_prompt: Optional system prompt, prepended when truthy

        Returns:
            List of {"role": ..., "parts": [...]} dicts in Gemini format.
        """
        history: List[Dict[str, Any]] = []
        if system_prompt:
            history.append({
                "role": "user",
                "parts": [system_prompt]
            })
            history.append({
                "role": "model",
                "parts": ["Understood. I will follow these instructions."]
            })
        for msg in messages:
            # Gemini uses "model" where OpenAI-style APIs use "assistant";
            # every other incoming role is mapped to "user".
            role = "model" if msg["role"] == "assistant" else "user"
            history.append({
                "role": role,
                "parts": [msg["content"]]
            })
        return history

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float = 0.7
    ) -> Dict[str, Any]:
        """Generate a response from Gemini.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            system_prompt: Optional system prompt to guide the AI's behavior
            max_tokens: Maximum tokens to generate in the response
            temperature: Sampling temperature (0.0 to 1.0)

        Returns:
            Dict containing:
                - content: The generated response text
                - token_count: Number of tokens used (estimated)
                - model: Model name used

        Raises:
            ValueError: If messages is empty
            Exception: If the Gemini API call fails
        """
        if not messages:
            # Previously an empty list surfaced as an opaque IndexError from
            # chat_history[-1], re-wrapped as a generic API error. Fail fast
            # with a clear message instead.
            raise ValueError("messages must contain at least one message")

        try:
            chat_history = self._build_history(messages, system_prompt)

            # All turns except the newest seed the chat; the newest one is
            # the prompt actually sent below.
            chat = self.model.start_chat(history=chat_history[:-1])

            generation_config = genai.types.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )

            response = chat.send_message(
                chat_history[-1]["parts"][0],
                generation_config=generation_config
            )

            content = response.text

            # Estimate token count (Gemini doesn't provide exact counts in
            # the free tier).
            token_count = self.count_tokens(content)

            return {
                "content": content,
                "token_count": token_count,
                "model": self.model_name
            }

        except Exception as e:
            # Chain the cause so the original traceback is preserved for
            # callers and logs.
            raise Exception(f"Gemini API error: {str(e)}") from e

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string.

        Uses a simple estimation: ~4 characters per token (rough
        approximation for English text). For more accurate counting,
        consider using the tiktoken library.

        Args:
            text: The text to count tokens for

        Returns:
            Estimated number of tokens in the text
        """
        return len(text) // 4