Amna2024 committed (verified)
Commit 4372e7f · 1 Parent(s): d6e43ef

Create llm.py

Files changed (1)
  1. RAG/llm.py +67 -0
RAG/llm.py ADDED
@@ -0,0 +1,67 @@
import google.generativeai as genai
from transformers.agents.llm_engine import MessageRole, get_clean_message_list
from typing import Dict, List, Optional

# Role conversion mapping for Gemini (carried over from task1.py)
gemini_role_conversions = {
    MessageRole.ASSISTANT: "model",
    MessageRole.USER: "user",
    MessageRole.SYSTEM: "user",  # Gemini has no system role; fold into the user turn
    MessageRole.TOOL_RESPONSE: "user",
}


class GeminiLLM:
    def __init__(self, gemini_key: str, model_name: str = "gemini-2.0-flash-exp"):
        genai.configure(api_key=gemini_key)
        self.model = genai.GenerativeModel(model_name)

    def format_messages(self, messages: List[Dict]) -> List[Dict]:
        # Normalize roles to Gemini's user/model vocabulary (_convert_messages logic from task1.py)
        cleaned_messages = get_clean_message_list(messages, role_conversions=gemini_role_conversions)

        # Handle system messages by buffering their content and prepending it
        # to the next user message.
        formatted_messages = []
        system_content = ""

        for msg in cleaned_messages:
            if msg["role"] == "user" and msg.get("content", "").startswith("System:"):
                system_content += msg["content"].replace("System:", "").strip() + "\n"
            else:
                if system_content and msg["role"] == "user":
                    msg["content"] = f"{system_content}\n{msg['content']}"
                    system_content = ""
                formatted_messages.append({
                    "role": msg["role"],
                    "parts": [msg["content"]],
                })

        return formatted_messages

    def generate_response(self, messages: List[Dict], stop_sequences: Optional[List[str]] = None) -> str:
        """
        Generate a response using the Gemini model.

        Args:
            messages (List[Dict]): List of message dictionaries.
            stop_sequences (Optional[List[str]]): Sequences at which to stop generation.

        Returns:
            str: Generated response text.
        """
        formatted_messages = self.format_messages(messages)

        # Create a chat session from everything except the final message;
        # passing the full list as history and then re-sending the last
        # message would duplicate it.
        chat = self.model.start_chat(history=formatted_messages[:-1])

        # Generate the response with deterministic generation parameters
        response = chat.send_message(
            formatted_messages[-1]["parts"][0],
            generation_config={
                "temperature": 0,
                "max_output_tokens": 4096,
                "stop_sequences": stop_sequences if stop_sequences else [],
            },
        )

        return response.text
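
For context, a minimal usage sketch (not part of the commit). The API key, import path, and messages below are illustrative placeholders; the message dicts follow the role/content convention that get_clean_message_list expects:

# Hypothetical usage -- key, import path, and messages are placeholders.
from RAG.llm import GeminiLLM

llm = GeminiLLM(gemini_key="YOUR_GEMINI_API_KEY")

messages = [
    {"role": "system", "content": "You are a concise RAG assistant."},
    {"role": "user", "content": "Summarize the retrieved passage in one sentence."},
]

print(llm.generate_response(messages, stop_sequences=["\n\n"]))

Since get_clean_message_list merges consecutive same-role messages after role conversion, the system message here should end up folded into the following user turn before it reaches Gemini.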