# Source: RAG-Testing / RAG / llm.py
# (Hugging Face upload by Amna2024 — "Create llm.py", commit 4372e7f)
import google.generativeai as genai
from transformers.agents.llm_engine import MessageRole, get_clean_message_list
from typing import List, Dict
# Role conversion mapping for Gemini: translates the framework's MessageRole
# enum values into the role strings the Gemini chat API accepts ("user"/"model").
gemini_role_conversions = {
MessageRole.ASSISTANT: "model",
MessageRole.USER: "user",
MessageRole.SYSTEM: "user", # Gemini doesn't have a system role, prepend to user
MessageRole.TOOL_RESPONSE: "user",  # tool output is fed back as a user turn
}
class GeminiLLM:
    """Thin wrapper around a Google Gemini chat model.

    Converts framework-style message dicts (``{"role": ..., "content": ...}``)
    into the Gemini chat format (``{"role": ..., "parts": [...]}``) and runs a
    single-turn generation against a chat session seeded with the prior history.
    """

    def __init__(self, gemini_key: str, model_name: str = "gemini-2.0-flash-exp"):
        """Configure the SDK with the API key and instantiate the model.

        Args:
            gemini_key (str): Google Generative AI API key.
            model_name (str): Gemini model identifier.
        """
        genai.configure(api_key=gemini_key)
        self.model = genai.GenerativeModel(model_name)

    def format_messages(self, messages: List[Dict]) -> List[Dict]:
        """Normalize messages into Gemini chat format.

        Roles are remapped via ``gemini_role_conversions``; user messages whose
        content starts with ``"System:"`` are treated as system instructions and
        prepended to the next real user message (Gemini has no system role).

        Args:
            messages (List[Dict]): Framework messages with "role"/"content" keys.

        Returns:
            List[Dict]: Messages shaped as ``{"role": ..., "parts": [content]}``.
        """
        cleaned_messages = get_clean_message_list(messages, role_conversions=gemini_role_conversions)
        formatted_messages = []
        system_content = ""
        for msg in cleaned_messages:
            # Use .get consistently: the original indexed msg["content"] on one
            # path, which could raise KeyError for a content-less message.
            content = msg.get("content", "")
            if msg["role"] == "user" and content.startswith("System:"):
                # Accumulate system instructions; they are merged into the next
                # user turn rather than emitted as their own message.
                system_content += content.replace("System:", "").strip() + "\n"
                continue
            if system_content and msg["role"] == "user":
                content = f"{system_content}\n{content}"
                system_content = ""
            # Build a fresh dict instead of mutating msg in place: the original
            # wrote msg["content"] back, side-effecting dicts it doesn't own.
            formatted_messages.append({
                "role": msg["role"],
                "parts": [content]
            })
        return formatted_messages

    def generate_response(self, messages: List[Dict], stop_sequences: List[str] = None) -> str:
        """
        Generate a response using the Gemini model.

        Args:
            messages (List[Dict]): List of message dictionaries.
            stop_sequences (List[str], optional): List of sequences to stop generation.

        Returns:
            str: Generated response.

        Raises:
            ValueError: If formatting yields no messages to send.
        """
        formatted_messages = self.format_messages(messages)
        if not formatted_messages:
            # Previously this crashed with an opaque IndexError on [-1].
            raise ValueError("No messages to send after formatting.")
        # BUG FIX: the history must exclude the final message, which is sent via
        # send_message below. The original passed the full list as history AND
        # re-sent the last entry, duplicating it in the conversation.
        chat = self.model.start_chat(history=formatted_messages[:-1])
        response = chat.send_message(
            formatted_messages[-1]["parts"][0],
            generation_config={
                "temperature": 0,  # deterministic decoding
                "max_output_tokens": 4096,
                "stop_sequences": stop_sequences if stop_sequences else [],
            },
        )
        return response.text