Pulastya0 commited on
Commit
b7cdb59
Β·
verified Β·
1 Parent(s): b551ce5

Update agent_langchain.py

Browse files
Files changed (1) hide show
  1. agent_langchain.py +72 -89
agent_langchain.py CHANGED
@@ -1,28 +1,29 @@
1
  import os
2
  import requests
3
  import torch
 
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import chromadb
6
  from chromadb.config import Settings
7
- from chromadb.utils import embedding_functions
8
- from langchain.agents import initialize_agent, Tool
9
- from langchain.agents import AgentType
10
  from langchain.memory import ConversationBufferMemory
11
 
12
- # -------------------------------
13
- # Environment & URLs
14
- # -------------------------------
15
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
16
- GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"
17
  ROUTING_URL = os.environ.get("ROUTING_URL") # Space 2 URL
18
  SPACE_URL = os.environ.get("SPACE_URL", "http://localhost:7860")
19
 
 
20
  os.environ["HF_HOME"] = "/tmp/huggingface"
21
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
22
  os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface"
23
- # -------------------------------
24
- # Label Dictionary
25
- # -------------------------------
 
26
  LABEL_DICTIONARY = {
27
  "I1": "Low Impact",
28
  "I2": "Medium Impact",
@@ -39,32 +40,15 @@ LABEL_DICTIONARY = {
39
  "T5": "Question"
40
  }
41
 
42
- # -------------------------------
43
- # Load Classification Model
44
- # -------------------------------
45
  clf_model_name = "DavinciTech/BERT_Categorizer"
46
  clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
47
  clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name)
48
 
49
- # -------------------------------
50
- # Initialize ChromaDB Client for KB
51
- # -------------------------------
52
- # βœ… Use new API β€” persistent on Hugging Face writable directory
53
- chroma_client = chromadb.PersistentClient(path="/tmp/chroma")
54
-
55
- # βœ… Create or get your KB collection
56
- kb_collection = chroma_client.get_or_create_collection("Knowledge_Base")
57
-
58
- COLLECTION_NAME = "Knowledge_Base"
59
- try:
60
- kb_collection = chroma_client.get_collection(COLLECTION_NAME)
61
- except:
62
- kb_collection = None
63
-
64
- # -------------------------------
65
- # Classification Function
66
- # -------------------------------
67
  def classify_ticket(text):
 
68
  inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
69
  outputs = clf_model(**inputs)
70
  logits = outputs.logits[0]
@@ -79,10 +63,11 @@ def classify_ticket(text):
79
  "type": LABEL_DICTIONARY[f"T{type_idx}"]
80
  }
81
 
82
- # -------------------------------
83
- # Routing Function
84
- # -------------------------------
85
  def call_routing(text, retries=3, delay=1):
 
86
  url = ROUTING_URL if ROUTING_URL else f"{SPACE_URL}/route"
87
  for attempt in range(retries):
88
  try:
@@ -96,89 +81,86 @@ def call_routing(text, retries=3, delay=1):
96
  else:
97
  return "General IT"
98
 
99
- # -------------------------------
100
- # KB Query
101
- # -------------------------------
 
 
 
 
 
 
 
 
 
 
 
102
  def query_kb(text, top_k=1):
 
103
  if not kb_collection:
104
- return {"answer": "⚠️ KB not set up. Call /setup first.", "confidence": 0.0}
105
 
106
  results = kb_collection.query(query_texts=[text], n_results=top_k)
107
- if not results or len(results['documents'][0]) == 0:
108
  return {"answer": "No relevant KB found.", "confidence": 0.0}
109
 
110
  return {
111
- "answer": results['documents'][0][0],
112
- "confidence": results['distances'][0][0] if results.get('distances') else 0.0,
113
- "metadata": results['metadatas'][0][0] if results['metadatas'][0] else {}
114
  }
115
 
116
- # -------------------------------
117
- # Gemini LLM Wrapper
118
- # -------------------------------
119
- class GeminiLLM:
120
- def __init__(self, api_key=GEMINI_API_KEY):
121
- self.api_key = api_key
122
- self.api_url = GEMINI_API_URL
123
-
124
- def __call__(self, prompt: str):
125
- if not self.api_key:
126
- return {"text": "⚠️ Gemini API key not set."}
127
- payload = {"contents": [{"parts": [{"text": prompt}]}]}
128
- headers = {"Authorization": f"Bearer {self.api_key}"}
129
- try:
130
- resp = requests.post(self.api_url, json=payload, headers=headers)
131
- resp.raise_for_status()
132
- data = resp.json()
133
- text = data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
134
- return text
135
- except:
136
- return "⚠️ Gemini API call failed."
137
-
138
- # -------------------------------
139
- # Define LangChain Tools
140
- # -------------------------------
141
  tools = [
142
  Tool(
143
  name="TicketClassifier",
144
  func=lambda text: classify_ticket(text),
145
- description="Classifies a ticket into impact, urgency, and type. Mandatory tool."
146
  ),
147
  Tool(
148
  name="RoutingTool",
149
  func=lambda text: call_routing(text),
150
- description="Assigns a department for the ticket via Space 2. Mandatory tool."
151
  ),
152
  Tool(
153
  name="KnowledgeBaseTool",
154
  func=lambda text: query_kb(text)["answer"],
155
- description="Searches KB for relevant solution. Returns answer text."
156
  )
157
  ]
158
 
159
- # -------------------------------
160
- # Initialize Memory
161
- # -------------------------------
162
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
163
 
164
- # -------------------------------
165
- # Initialize Agent
166
- # -------------------------------
167
  agent_executor = initialize_agent(
168
  tools=tools,
169
- llm=GeminiLLM(),
170
  agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
171
  memory=memory,
172
  verbose=False
173
  )
174
 
175
- # -------------------------------
176
- # Process Ticket Function
177
- # -------------------------------
178
  def process_ticket_langchain(ticket_text):
 
179
  reasoning_trace = []
180
 
181
- # Step 1: Classifier
182
  classification = classify_ticket(ticket_text)
183
  reasoning_trace.append(f"[Classifier] Impact: {classification['impact']}, Urgency: {classification['urgency']}, Type: {classification['type']}")
184
 
@@ -188,26 +170,27 @@ def process_ticket_langchain(ticket_text):
188
 
189
  # Step 3: KB Search
190
  kb_result = query_kb(ticket_text)
191
- reasoning_trace.append(f"[KB Search] Top answer: '{kb_result['answer']}' (confidence: {kb_result['confidence']})")
192
 
193
- # Step 4: Decision KB vs LLM
194
  if kb_result["confidence"] >= 0.75:
195
  final_answer = kb_result["answer"]
196
  status = "resolved"
197
- reasoning_trace.append("[Decision] KB confidence high β†’ ticket resolved via KB.")
198
  else:
199
  llm_prompt = f"""
200
- You are a professional IT helpdesk assistant.
201
  A user submitted the following ticket: "{ticket_text}"
 
202
  Ticket classification: {classification}
203
  Assigned department: {department}
204
- KB Search result: {kb_result['answer']} (confidence: {kb_result['confidence']})
205
 
206
- Provide a professional and descriptive solution or guidance based on this information.
207
  """
208
- final_answer = GeminiLLM()(llm_prompt)
209
  status = "escalated"
210
- reasoning_trace.append("[Decision] KB confidence low β†’ ticket escalated via Gemini LLM.")
211
 
212
  return {
213
  "status": status,
 
1
  import os
2
  import requests
3
  import torch
4
+ import time
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  import chromadb
7
  from chromadb.config import Settings
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain.agents import initialize_agent, Tool, AgentType
 
10
  from langchain.memory import ConversationBufferMemory
11
 
12
+ # ==============================================================
13
+ # 🌐 ENVIRONMENT & GLOBAL SETTINGS
14
+ # ==============================================================
15
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
16
  ROUTING_URL = os.environ.get("ROUTING_URL") # Space 2 URL
17
  SPACE_URL = os.environ.get("SPACE_URL", "http://localhost:7860")
18
 
19
+ # Hugging Face Space writable paths
20
  os.environ["HF_HOME"] = "/tmp/huggingface"
21
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
22
  os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface"
23
+
24
+ # ==============================================================
25
+ # 🏷️ LABEL DICTIONARY
26
+ # ==============================================================
27
  LABEL_DICTIONARY = {
28
  "I1": "Low Impact",
29
  "I2": "Medium Impact",
 
40
  "T5": "Question"
41
  }
42
 
43
+ # ==============================================================
44
+ # πŸ€– LOAD CLASSIFICATION MODEL
45
+ # ==============================================================
46
  clf_model_name = "DavinciTech/BERT_Categorizer"
47
  clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
48
  clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def classify_ticket(text):
51
+ """Classify the ticket into Impact, Urgency, and Type."""
52
  inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
53
  outputs = clf_model(**inputs)
54
  logits = outputs.logits[0]
 
63
  "type": LABEL_DICTIONARY[f"T{type_idx}"]
64
  }
65
 
66
+ # ==============================================================
67
+ # 🧭 ROUTING FUNCTION (Space 2)
68
+ # ==============================================================
69
  def call_routing(text, retries=3, delay=1):
70
+ """Call Space 2 routing endpoint."""
71
  url = ROUTING_URL if ROUTING_URL else f"{SPACE_URL}/route"
72
  for attempt in range(retries):
73
  try:
 
81
  else:
82
  return "General IT"
83
 
84
+ # ==============================================================
85
+ # πŸ“š KNOWLEDGE BASE SETUP
86
+ # ==============================================================
87
+ # Persistent Chroma client (new API)
88
+ chroma_client = chromadb.PersistentClient(path="/tmp/chroma")
89
+
90
+ COLLECTION_NAME = "knowledge_base"
91
+
92
+ try:
93
+ kb_collection = chroma_client.get_or_create_collection(COLLECTION_NAME)
94
+ except Exception as e:
95
+ kb_collection = None
96
+ print("⚠️ Could not initialize KB:", e)
97
+
98
  def query_kb(text, top_k=1):
99
+ """Query the knowledge base for relevant solutions."""
100
  if not kb_collection:
101
+ return {"answer": "⚠️ KB not set up.", "confidence": 0.0}
102
 
103
  results = kb_collection.query(query_texts=[text], n_results=top_k)
104
+ if not results or not results.get("documents") or len(results["documents"][0]) == 0:
105
  return {"answer": "No relevant KB found.", "confidence": 0.0}
106
 
107
  return {
108
+ "answer": results["documents"][0][0],
109
+ "confidence": results.get("distances", [[0.0]])[0][0],
110
+ "metadata": results.get("metadatas", [[{}]])[0][0]
111
  }
112
 
113
+ # ==============================================================
114
+ # 🧠 GEMINI LLM (Official LangChain Integration)
115
+ # ==============================================================
116
+ llm = ChatGoogleGenerativeAI(
117
+ model="gemini-2.5-flash",
118
+ temperature=0.3,
119
+ google_api_key=GEMINI_API_KEY
120
+ )
121
+
122
+ # ==============================================================
123
+ # 🧰 DEFINE LANGCHAIN TOOLS
124
+ # ==============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  tools = [
126
  Tool(
127
  name="TicketClassifier",
128
  func=lambda text: classify_ticket(text),
129
+ description="Classifies the ticket into impact, urgency, and type. Mandatory tool."
130
  ),
131
  Tool(
132
  name="RoutingTool",
133
  func=lambda text: call_routing(text),
134
+ description="Determines which department should handle the ticket (via Space 2). Mandatory tool."
135
  ),
136
  Tool(
137
  name="KnowledgeBaseTool",
138
  func=lambda text: query_kb(text)["answer"],
139
+ description="Searches the KB for relevant solutions. Returns a descriptive answer."
140
  )
141
  ]
142
 
143
+ # ==============================================================
144
+ # πŸ’¬ MEMORY & AGENT INITIALIZATION
145
+ # ==============================================================
146
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
147
 
 
 
 
148
  agent_executor = initialize_agent(
149
  tools=tools,
150
+ llm=llm,
151
  agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
152
  memory=memory,
153
  verbose=False
154
  )
155
 
156
+ # ==============================================================
157
+ # 🧾 MAIN TICKET PROCESSOR
158
+ # ==============================================================
159
  def process_ticket_langchain(ticket_text):
160
+ """Full pipeline: classify β†’ route β†’ query KB β†’ decide KB vs Gemini."""
161
  reasoning_trace = []
162
 
163
+ # Step 1: Classification
164
  classification = classify_ticket(ticket_text)
165
  reasoning_trace.append(f"[Classifier] Impact: {classification['impact']}, Urgency: {classification['urgency']}, Type: {classification['type']}")
166
 
 
170
 
171
  # Step 3: KB Search
172
  kb_result = query_kb(ticket_text)
173
+ reasoning_trace.append(f"[KB Search] Top Answer: '{kb_result['answer']}' (confidence: {kb_result['confidence']})")
174
 
175
+ # Step 4: KB vs LLM Decision
176
  if kb_result["confidence"] >= 0.75:
177
  final_answer = kb_result["answer"]
178
  status = "resolved"
179
+ reasoning_trace.append("[Decision] High KB confidence β†’ ticket resolved via KB.")
180
  else:
181
  llm_prompt = f"""
182
+ You are a professional IT helpdesk agent.
183
  A user submitted the following ticket: "{ticket_text}"
184
+
185
  Ticket classification: {classification}
186
  Assigned department: {department}
187
+ Knowledge base result: {kb_result['answer']} (confidence: {kb_result['confidence']})
188
 
189
+ Please provide a clear, descriptive, and professional IT helpdesk response.
190
  """
191
+ final_answer = llm.invoke(llm_prompt).content
192
  status = "escalated"
193
+ reasoning_trace.append("[Decision] Low KB confidence β†’ fallback to Gemini LLM for escalation.")
194
 
195
  return {
196
  "status": status,