from typing import Dict, List

import google.generativeai as genai
from transformers.agents.llm_engine import MessageRole, get_clean_message_list

# Role conversion mapping: Gemini's chat API only accepts "user" and "model"
# roles, so system and tool-response messages are mapped onto "user".
gemini_role_conversions = {
    MessageRole.ASSISTANT: "model",
    MessageRole.USER: "user",
    MessageRole.SYSTEM: "user",  # no system role; folded into the next user turn
    MessageRole.TOOL_RESPONSE: "user",
}
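As a dependency-free illustration of what this mapping accomplishes, the sketch below applies the same role rewrites to plain dicts (`role_map` and `convert_roles` are illustrative names, not part of the SDK or of `transformers`; the real code delegates this to `get_clean_message_list`):

```python
# Illustrative stand-in for the role conversion step, using string keys
# instead of the MessageRole enum.
role_map = {
    "assistant": "model",
    "user": "user",
    "system": "user",        # Gemini's chat API has no system role
    "tool-response": "user",
}

def convert_roles(messages):
    """Rewrite each message's role using the mapping above."""
    return [{"role": role_map[m["role"]], "content": m["content"]} for m in messages]
```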

class GeminiLLM:
    """LLM engine that adapts transformers agent messages to the Gemini SDK."""

    def __init__(self, gemini_key: str, model_name: str = "gemini-2.0-flash-exp"):
        genai.configure(api_key=gemini_key)
        self.model = genai.GenerativeModel(model_name)

    def format_messages(self, messages: List[Dict]) -> List[Dict]:
        """Convert agent messages into Gemini's {"role", "parts"} format."""
        cleaned_messages = get_clean_message_list(messages, role_conversions=gemini_role_conversions)

        # Fold any "System:"-prefixed user turns into the next real user message
        formatted_messages = []
        system_content = ""

        for msg in cleaned_messages:
            if msg["role"] == "user" and msg.get("content", "").startswith("System:"):
                system_content += msg["content"].replace("System:", "").strip() + "\n"
            else:
                if system_content and msg["role"] == "user":
                    msg["content"] = f"{system_content}\n{msg['content']}"
                    system_content = ""
                formatted_messages.append({
                    "role": msg["role"],
                    "parts": [msg["content"]]
                })

        return formatted_messages

    def generate_response(self, messages: List[Dict], stop_sequences: List[str] = None) -> str:
        """
        Generate a response using the Gemini model.

        Args:
            messages (List[Dict]): List of message dictionaries.
            stop_sequences (List[str], optional): Sequences at which to stop generation.

        Returns:
            str: The generated response text.
        """
        formatted_messages = self.format_messages(messages)

        # Start the chat with everything except the final message as history,
        # then send the final message itself so it is not duplicated in the prompt.
        chat = self.model.start_chat(history=formatted_messages[:-1])

        response = chat.send_message(
            formatted_messages[-1]["parts"][0],
            generation_config={
                "temperature": 0,
                "max_output_tokens": 4096,
                "stop_sequences": stop_sequences or [],
            },
        )

        return response.text
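To make the `format_messages` transform concrete without an API key or the `transformers` dependency, here is a minimal standalone sketch of the same fold-system-into-user logic on plain dicts (`to_gemini_format` is a hypothetical helper name; the real method additionally runs `get_clean_message_list` first):

```python
def to_gemini_format(messages):
    """Collapse "System:"-prefixed user turns into the next user message
    and wrap each message's content in Gemini's "parts" list."""
    formatted, system_content = [], ""
    for msg in messages:
        if msg["role"] == "user" and msg.get("content", "").startswith("System:"):
            # Accumulate system text instead of emitting a separate turn.
            system_content += msg["content"].replace("System:", "").strip() + "\n"
        else:
            content = msg["content"]
            if system_content and msg["role"] == "user":
                content = f"{system_content}\n{content}"
                system_content = ""
            formatted.append({"role": msg["role"], "parts": [content]})
    return formatted
```

Note one design consequence of this approach: accumulated system text is only attached to a *user* turn, so a conversation whose last message is a system instruction would silently drop it.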