Pulastya0 commited on
Commit
aa4ac8d
·
verified ·
1 Parent(s): 332dd63

Upload 6 files

Browse files
Files changed (6) hide show
  1. .dockerignore +6 -0
  2. Dockerfile +33 -0
  3. agent_langchain.py +214 -0
  4. app.py +121 -0
  5. main.py +121 -0
  6. requirements.txt +11 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .git/
6
+ data/huggingface-cache/*
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use official Python base
2
+ FROM python:3.11-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Set locales and UTF-8
8
+ ENV LANG=C.UTF-8
9
+ ENV LC_ALL=C.UTF-8
10
+
11
+ # Set Hugging Face cache
12
+ ENV HF_HOME="/data/huggingface-cache"
13
+ ENV TRANSFORMERS_CACHE="/data/huggingface-cache"
14
+
15
+ # Install system dependencies
16
+ RUN apt-get update && \
17
+ apt-get install -y --no-install-recommends git build-essential && \
18
+ rm -rf /var/lib/apt/lists/*
19
+
20
+ # Copy requirements
21
+ COPY requirements.txt .
22
+
23
+ # Install Python dependencies
24
+ RUN pip install --no-cache-dir -r requirements.txt
25
+
26
+ # Copy app files
27
+ COPY . .
28
+
29
+ # Expose port
30
+ EXPOSE 7860
31
+
32
+ # Start Uvicorn
33
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
agent_langchain.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import chromadb
6
+ from chromadb.config import Settings
7
+ from chromadb.utils import embedding_functions
8
+ from langchain.agents import initialize_agent, Tool
9
+ from langchain.agents import AgentType
10
+ from langchain.memory import ConversationBufferMemory
11
+
12
+ # -------------------------------
13
+ # Environment & URLs
14
+ # -------------------------------
15
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
16
+ GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
17
+ ROUTING_URL = os.environ.get("ROUTING_URL") # Space 2 URL
18
+ SPACE_URL = os.environ.get("SPACE_URL", "http://localhost:7860")
19
+
20
+ # -------------------------------
21
+ # Label Dictionary
22
+ # -------------------------------
23
+ LABEL_DICTIONARY = {
24
+ "I1": "Low Impact",
25
+ "I2": "Medium Impact",
26
+ "I3": "High Impact",
27
+ "I4": "Critical Impact",
28
+ "U1": "Low Urgency",
29
+ "U2": "Medium Urgency",
30
+ "U3": "High Urgency",
31
+ "U4": "Critical Urgency",
32
+ "T1": "Information",
33
+ "T2": "Incident",
34
+ "T3": "Problem",
35
+ "T4": "Request",
36
+ "T5": "Question"
37
+ }
38
+
39
+ # -------------------------------
40
+ # Load Classification Model
41
+ # -------------------------------
42
+ clf_model_name = "DavinciTech/BERT_Categorizer"
43
+ clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
44
+ clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name)
45
+
46
+ # -------------------------------
47
+ # Initialize ChromaDB Client for KB
48
+ # -------------------------------
49
+ chroma_client = chromadb.Client(Settings(
50
+ chroma_db_impl="duckdb+parquet",
51
+ persist_directory="/data/chroma_db"
52
+ ))
53
+
54
+ COLLECTION_NAME = "kb_collection"
55
+ try:
56
+ kb_collection = chroma_client.get_collection(COLLECTION_NAME)
57
+ except:
58
+ kb_collection = None
59
+
60
+ # -------------------------------
61
+ # Classification Function
62
+ # -------------------------------
63
+ def classify_ticket(text):
64
+ inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
65
+ outputs = clf_model(**inputs)
66
+ logits = outputs.logits[0]
67
+
68
+ impact_idx = torch.argmax(logits[:4]).item() + 1
69
+ urgency_idx = torch.argmax(logits[4:8]).item() + 1
70
+ type_idx = torch.argmax(logits[8:]).item() + 1
71
+
72
+ return {
73
+ "impact": LABEL_DICTIONARY[f"I{impact_idx}"],
74
+ "urgency": LABEL_DICTIONARY[f"U{urgency_idx}"],
75
+ "type": LABEL_DICTIONARY[f"T{type_idx}"]
76
+ }
77
+
78
+ # -------------------------------
79
+ # Routing Function
80
+ # -------------------------------
81
+ def call_routing(text, retries=3, delay=1):
82
+ url = ROUTING_URL if ROUTING_URL else f"{SPACE_URL}/route"
83
+ for attempt in range(retries):
84
+ try:
85
+ resp = requests.post(url, json={"text": text}, timeout=5)
86
+ resp.raise_for_status()
87
+ data = resp.json()
88
+ return data.get("department", "General IT")
89
+ except Exception:
90
+ if attempt < retries - 1:
91
+ time.sleep(delay)
92
+ else:
93
+ return "General IT"
94
+
95
+ # -------------------------------
96
+ # KB Query
97
+ # -------------------------------
98
+ def query_kb(text, top_k=1):
99
+ if not kb_collection:
100
+ return {"answer": "⚠️ KB not set up. Call /setup first.", "confidence": 0.0}
101
+
102
+ results = kb_collection.query(query_texts=[text], n_results=top_k)
103
+ if not results or len(results['documents'][0]) == 0:
104
+ return {"answer": "No relevant KB found.", "confidence": 0.0}
105
+
106
+ return {
107
+ "answer": results['documents'][0][0],
108
+ "confidence": results['distances'][0][0] if results.get('distances') else 0.0,
109
+ "metadata": results['metadatas'][0][0] if results['metadatas'][0] else {}
110
+ }
111
+
112
+ # -------------------------------
113
+ # Gemini LLM Wrapper
114
+ # -------------------------------
115
+ class GeminiLLM:
116
+ def __init__(self, api_key=GEMINI_API_KEY):
117
+ self.api_key = api_key
118
+ self.api_url = GEMINI_API_URL
119
+
120
+ def __call__(self, prompt: str):
121
+ if not self.api_key:
122
+ return {"text": "⚠️ Gemini API key not set."}
123
+ payload = {"contents": [{"parts": [{"text": prompt}]}]}
124
+ headers = {"Authorization": f"Bearer {self.api_key}"}
125
+ try:
126
+ resp = requests.post(self.api_url, json=payload, headers=headers)
127
+ resp.raise_for_status()
128
+ data = resp.json()
129
+ text = data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
130
+ return text
131
+ except:
132
+ return "⚠️ Gemini API call failed."
133
+
134
+ # -------------------------------
135
+ # Define LangChain Tools
136
+ # -------------------------------
137
+ tools = [
138
+ Tool(
139
+ name="TicketClassifier",
140
+ func=lambda text: classify_ticket(text),
141
+ description="Classifies a ticket into impact, urgency, and type. Mandatory tool."
142
+ ),
143
+ Tool(
144
+ name="RoutingTool",
145
+ func=lambda text: call_routing(text),
146
+ description="Assigns a department for the ticket via Space 2. Mandatory tool."
147
+ ),
148
+ Tool(
149
+ name="KnowledgeBaseTool",
150
+ func=lambda text: query_kb(text)["answer"],
151
+ description="Searches KB for relevant solution. Returns answer text."
152
+ )
153
+ ]
154
+
155
+ # -------------------------------
156
+ # Initialize Memory
157
+ # -------------------------------
158
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
159
+
160
+ # -------------------------------
161
+ # Initialize Agent
162
+ # -------------------------------
163
+ agent_executor = initialize_agent(
164
+ tools=tools,
165
+ llm=GeminiLLM(),
166
+ agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
167
+ memory=memory,
168
+ verbose=False
169
+ )
170
+
171
+ # -------------------------------
172
+ # Process Ticket Function
173
+ # -------------------------------
174
+ def process_ticket_langchain(ticket_text):
175
+ reasoning_trace = []
176
+
177
+ # Step 1: Classifier
178
+ classification = classify_ticket(ticket_text)
179
+ reasoning_trace.append(f"[Classifier] Impact: {classification['impact']}, Urgency: {classification['urgency']}, Type: {classification['type']}")
180
+
181
+ # Step 2: Routing
182
+ department = call_routing(ticket_text)
183
+ reasoning_trace.append(f"[Routing] Assigned Department: {department}")
184
+
185
+ # Step 3: KB Search
186
+ kb_result = query_kb(ticket_text)
187
+ reasoning_trace.append(f"[KB Search] Top answer: '{kb_result['answer']}' (confidence: {kb_result['confidence']})")
188
+
189
+ # Step 4: Decision KB vs LLM
190
+ if kb_result["confidence"] >= 0.75:
191
+ final_answer = kb_result["answer"]
192
+ status = "resolved"
193
+ reasoning_trace.append("[Decision] KB confidence high → ticket resolved via KB.")
194
+ else:
195
+ llm_prompt = f"""
196
+ You are a professional IT helpdesk assistant.
197
+ A user submitted the following ticket: "{ticket_text}"
198
+ Ticket classification: {classification}
199
+ Assigned department: {department}
200
+ KB Search result: {kb_result['answer']} (confidence: {kb_result['confidence']})
201
+
202
+ Provide a professional and descriptive solution or guidance based on this information.
203
+ """
204
+ final_answer = GeminiLLM()(llm_prompt)
205
+ status = "escalated"
206
+ reasoning_trace.append("[Decision] KB confidence low → ticket escalated via Gemini LLM.")
207
+
208
+ return {
209
+ "status": status,
210
+ "classification": classification,
211
+ "department": department,
212
+ "answer": final_answer,
213
+ "reasoning_trace": reasoning_trace
214
+ }
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from agent_langchain import process_ticket_langchain, classify_ticket, call_routing, kb_collection
4
+ import chromadb
5
+ from chromadb.config import Settings
6
+ from chromadb.utils import embedding_functions
7
+ import json
8
+ import os
9
+
10
+ app = FastAPI(title="Smart Helpdesk AI Agent LangChain")
11
+
12
+ # -------------------------------
13
+ # Request Models
14
+ # -------------------------------
15
+ class TicketRequest(BaseModel):
16
+ text: str
17
+ user_email: str = None
18
+
19
+ class SetupRequest(BaseModel):
20
+ kb_file: str # path to KB.json
21
+
22
+ # -------------------------------
23
+ # KB Setup Endpoint
24
+ # -------------------------------
25
+ @app.post("/setup")
26
+ async def setup_endpoint(req: SetupRequest):
27
+ """Embed KB.json and store in ChromaDB"""
28
+ global kb_collection
29
+ if not os.path.exists(req.kb_file):
30
+ raise HTTPException(status_code=404, detail="KB.json file not found")
31
+
32
+ # Load KB
33
+ with open(req.kb_file, "r") as f:
34
+ kb_data = json.load(f)
35
+
36
+ # Create ChromaDB collection if not exists
37
+ chroma_client = chromadb.Client(Settings(
38
+ chroma_db_impl="duckdb+parquet",
39
+ persist_directory="/data/chroma_db"
40
+ ))
41
+
42
+ try:
43
+ kb_collection = chroma_client.get_collection("kb_collection")
44
+ except:
45
+ kb_collection = chroma_client.create_collection("kb_collection")
46
+
47
+ # Setup embedding function
48
+ embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
49
+ model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
50
+ )
51
+
52
+ # Add KB entries
53
+ for entry in kb_data["knowledge_base"]:
54
+ kb_collection.add(
55
+ documents=[entry["answer"]],
56
+ metadatas=[{
57
+ "id": entry["id"],
58
+ "category": entry.get("category", ""),
59
+ "question_variations": entry.get("question_variations", []),
60
+ "keywords": entry.get("keywords", [])
61
+ }],
62
+ ids=[entry["id"]],
63
+ embedding_function=embedding_func
64
+ )
65
+
66
+ kb_collection.persist()
67
+ return {"status": "KB embedded and stored successfully"}
68
+
69
+ # -------------------------------
70
+ # Step-by-Step Endpoints
71
+ # -------------------------------
72
+
73
+ @app.post("/classify")
74
+ async def classify_endpoint(ticket: TicketRequest):
75
+ """Classify the ticket (impact, urgency, type)"""
76
+ classification = classify_ticket(ticket.text)
77
+ return {"classification": classification}
78
+
79
+ @app.post("/route")
80
+ async def route_endpoint(ticket: TicketRequest):
81
+ """Route the ticket to department (Space 2)"""
82
+ department = call_routing(ticket.text)
83
+ return {"department": department}
84
+
85
+ @app.post("/kb_query")
86
+ async def kb_query_endpoint(ticket: TicketRequest):
87
+ """Query KB directly"""
88
+ if not kb_collection:
89
+ raise HTTPException(status_code=400, detail="KB not set up. Call /setup first.")
90
+ result = kb_collection.query(query_texts=[ticket.text], n_results=1)
91
+ if not result or len(result['documents'][0]) == 0:
92
+ return {"answer": "No relevant KB found."}
93
+ return {"answer": result['documents'][0][0], "confidence": result['distances'][0][0] if result.get('distances') else 0.0}
94
+
95
+ # -------------------------------
96
+ # Full Ticket Orchestration
97
+ # -------------------------------
98
+ @app.post("/orchestrate")
99
+ async def orchestrate_endpoint(ticket: TicketRequest):
100
+ """Full ticket orchestration via LangChain agent with nicely formatted reasoning trace"""
101
+ result = process_ticket_langchain(ticket.text)
102
+
103
+ # Format reasoning trace for readability
104
+ formatted_trace = [{"step": idx + 1, "description": line} for idx, line in enumerate(result.get("reasoning_trace", []))]
105
+
106
+ response = {
107
+ "status": result["status"],
108
+ "classification": result["classification"],
109
+ "department": result["department"],
110
+ "answer": result["answer"],
111
+ "reasoning_trace": formatted_trace
112
+ }
113
+
114
+ return response
115
+
116
+ # -------------------------------
117
+ # Health Check
118
+ # -------------------------------
119
+ @app.get("/health")
120
+ async def health():
121
+ return {"status": "ok"}
main.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from agent_langchain import process_ticket_langchain, classify_ticket, call_routing, kb_collection
4
+ import chromadb
5
+ from chromadb.config import Settings
6
+ from chromadb.utils import embedding_functions
7
+ import json
8
+ import os
9
+
10
+ app = FastAPI(title="Smart Helpdesk AI Agent LangChain")
11
+
12
+ # -------------------------------
13
+ # Request Models
14
+ # -------------------------------
15
+ class TicketRequest(BaseModel):
16
+ text: str
17
+ user_email: str = None
18
+
19
+ class SetupRequest(BaseModel):
20
+ kb_file: str # path to KB.json
21
+
22
+ # -------------------------------
23
+ # KB Setup Endpoint
24
+ # -------------------------------
25
+ @app.post("/setup")
26
+ async def setup_endpoint(req: SetupRequest):
27
+ """Embed KB.json and store in ChromaDB"""
28
+ global kb_collection
29
+ if not os.path.exists(req.kb_file):
30
+ raise HTTPException(status_code=404, detail="KB.json file not found")
31
+
32
+ # Load KB
33
+ with open(req.kb_file, "r") as f:
34
+ kb_data = json.load(f)
35
+
36
+ # Create ChromaDB collection if not exists
37
+ chroma_client = chromadb.Client(Settings(
38
+ chroma_db_impl="duckdb+parquet",
39
+ persist_directory="/data/chroma_db"
40
+ ))
41
+
42
+ try:
43
+ kb_collection = chroma_client.get_collection("kb_collection")
44
+ except:
45
+ kb_collection = chroma_client.create_collection("kb_collection")
46
+
47
+ # Setup embedding function
48
+ embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
49
+ model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
50
+ )
51
+
52
+ # Add KB entries
53
+ for entry in kb_data["knowledge_base"]:
54
+ kb_collection.add(
55
+ documents=[entry["answer"]],
56
+ metadatas=[{
57
+ "id": entry["id"],
58
+ "category": entry.get("category", ""),
59
+ "question_variations": entry.get("question_variations", []),
60
+ "keywords": entry.get("keywords", [])
61
+ }],
62
+ ids=[entry["id"]],
63
+ embedding_function=embedding_func
64
+ )
65
+
66
+ kb_collection.persist()
67
+ return {"status": "KB embedded and stored successfully"}
68
+
69
+ # -------------------------------
70
+ # Step-by-Step Endpoints
71
+ # -------------------------------
72
+
73
+ @app.post("/classify")
74
+ async def classify_endpoint(ticket: TicketRequest):
75
+ """Classify the ticket (impact, urgency, type)"""
76
+ classification = classify_ticket(ticket.text)
77
+ return {"classification": classification}
78
+
79
+ @app.post("/route")
80
+ async def route_endpoint(ticket: TicketRequest):
81
+ """Route the ticket to department (Space 2)"""
82
+ department = call_routing(ticket.text)
83
+ return {"department": department}
84
+
85
+ @app.post("/kb_query")
86
+ async def kb_query_endpoint(ticket: TicketRequest):
87
+ """Query KB directly"""
88
+ if not kb_collection:
89
+ raise HTTPException(status_code=400, detail="KB not set up. Call /setup first.")
90
+ result = kb_collection.query(query_texts=[ticket.text], n_results=1)
91
+ if not result or len(result['documents'][0]) == 0:
92
+ return {"answer": "No relevant KB found."}
93
+ return {"answer": result['documents'][0][0], "confidence": result['distances'][0][0] if result.get('distances') else 0.0}
94
+
95
+ # -------------------------------
96
+ # Full Ticket Orchestration
97
+ # -------------------------------
98
+ @app.post("/orchestrate")
99
+ async def orchestrate_endpoint(ticket: TicketRequest):
100
+ """Full ticket orchestration via LangChain agent with nicely formatted reasoning trace"""
101
+ result = process_ticket_langchain(ticket.text)
102
+
103
+ # Format reasoning trace for readability
104
+ formatted_trace = [{"step": idx + 1, "description": line} for idx, line in enumerate(result.get("reasoning_trace", []))]
105
+
106
+ response = {
107
+ "status": result["status"],
108
+ "classification": result["classification"],
109
+ "department": result["department"],
110
+ "answer": result["answer"],
111
+ "reasoning_trace": formatted_trace
112
+ }
113
+
114
+ return response
115
+
116
+ # -------------------------------
117
+ # Health Check
118
+ # -------------------------------
119
+ @app.get("/health")
120
+ async def health():
121
+ return {"status": "ok"}
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.109.1
2
+ uvicorn[standard]==0.23.2
3
+ transformers==4.34.0
4
+ torch==2.2.0
5
+ sentence-transformers==2.2.2
6
+ requests==2.31.0
7
+ pydantic==2.6.1
8
+ chromadb==0.4.4
9
+ langchain==0.1.0
10
+ protobuf==4.23.4
11
+ accelerate==0.23.0