Prathamesh1420 committed on
Commit 23f378c · verified · 1 Parent(s): dc64232

Update app.py

Files changed (1)
  1. app.py +0 -359
app.py CHANGED
@@ -1,362 +1,3 @@
- '''
- ####
- import os
- import gradio as gr
- import requests
- from pinecone import Pinecone
- from langchain.prompts import PromptTemplate
- from langchain.chains.llm import LLMChain
- from langchain.llms.base import LLM
- from typing import Optional, List, Mapping, Any
- from langchain.embeddings import HuggingFaceEmbeddings
-
- # ----------- 1. Custom LLM to call your LitServe endpoint -----------
- class LitServeLLM(LLM):
-     endpoint_url: str
-
-     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-         payload = {"prompt": prompt}
-         response = requests.post(self.endpoint_url, json=payload)
-         if response.status_code == 200:
-             data = response.json()
-             return data.get("response", "").strip()
-         else:
-             raise ValueError(f"Request failed: {response.status_code} {response.text}")
-
-     @property
-     def _identifying_params(self) -> Mapping[str, Any]:
-         return {"endpoint_url": self.endpoint_url}
-
-     @property
-     def _llm_type(self) -> str:
-         return "litserve_llm"
-
-
- # ----------- 2. Connect to Pinecone -----------
- PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
- pc = Pinecone(api_key=PINECONE_API_KEY)
- index = pc.Index("rag-granite-index")
-
- # ----------- 3. Load embedding model -----------
- embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
-
- # ----------- 4. Function to get top context from Pinecone -----------
- def get_retrieved_context(query: str, top_k=3):
-     query_embedding = embeddings_model.embed_query(query)
-     results = index.query(
-         namespace="rag-ns",
-         vector=query_embedding,
-         top_k=top_k,
-         include_metadata=True
-     )
-     context_parts = [match['metadata']['text'] for match in results['matches']]
-     return "\n".join(context_parts)
-
- # ----------- 5. Create LLMChain with your model -----------
- model = LitServeLLM(
-     endpoint_url="https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
- )
-
- prompt = PromptTemplate(
-     input_variables=["context", "question"],
-     template="""
- You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
- If the context has more details, summarize it concisely.
-
- Context:
- {context}
-
- Question: {question}
-
- Answer:
- """
- )
-
- llm_chain = LLMChain(llm=model, prompt=prompt)
-
- # ----------- 6. Main RAG Function -----------
- def rag_pipeline(question):
-     try:
-         retrieved_context = get_retrieved_context(question)
-         response = llm_chain.invoke({
-             "context": retrieved_context,
-             "question": question
-         })["text"].strip()
-
-         # Only keep what's after "Answer:"
-         if "Answer:" in response:
-             response = response.split("Answer:", 1)[-1].strip()
-
-         return response
-     except Exception as e:
-         return f"Error: {str(e)}"
-
-
- # ----------- 7. Gradio UI -----------
- with gr.Blocks() as demo:
-     gr.Markdown("# 🧠 RAG Chatbot (Pinecone + LitServe)")
-     question_input = gr.Textbox(label="Ask your question here")
-     answer_output = gr.Textbox(label="Answer")
-     ask_button = gr.Button("Get Answer")
-     ask_button.click(rag_pipeline, inputs=question_input, outputs=answer_output)
-
- if __name__ == "__main__":
-     demo.launch()
- '''
-
-
-
- '''
- import os
- import gradio as gr
- import requests
- import mlflow
- import dagshub
- from pinecone import Pinecone
- from langchain.prompts import PromptTemplate
- from langchain.chains.llm import LLMChain
- from langchain.llms.base import LLM
- from typing import Optional, List, Mapping, Any
- import time
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from dotenv import load_dotenv
- from datetime import datetime
-
- # Load environment variables
- pinecone_api_key = os.environ["PINECONE_API_KEY"]
-
- mlflow_tracking_uri = os.environ["MLFLOW_TRACKING_URI"]
-
- # ----------- DagsHub & MLflow Setup -----------
-
- dagshub.init(
-     repo_owner='prathamesh.khade20',
-     repo_name='Maintenance_AI_website',
-     mlflow=True
- )
-
- mlflow.set_tracking_uri(mlflow_tracking_uri)
- mlflow.set_experiment("Maintenance-RAG-Chatbot")
- mlflow.langchain.autolog()
-
-
- # Initialize MLflow run for app configuration
- with mlflow.start_run(run_name=f"App-Config-{datetime.now().strftime('%Y%m%d-%H%M%S')}") as setup_run:
-     # Log environment configuration
-     mlflow.log_params({
-         "pinecone_index": "rag-granite-index",
-         "embedding_model": "all-MiniLM-L6-v2",
-         "namespace": "rag-ns",
-         "top_k": 3,
-         "llm_endpoint": "https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
-     })
-
-     # Log important files as artifacts
-     mlflow.log_text("""
- You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
- If the context has more details, summarize it concisely.
- Context:
- {context}
- Question: {question}
- Answer:
- """, "artifacts/prompt_template.txt")
-
- # ----------- 1. Custom LLM for LitServe endpoint -----------
- class LitServeLLM(LLM):
-     endpoint_url: str
-
-     @mlflow.trace
-     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-         payload = {"prompt": prompt}
-
-         with mlflow.start_span("lit_serve_request"):
-             start_time = time.time()
-             response = requests.post(self.endpoint_url, json=payload)
-             latency = time.time() - start_time
-
-             mlflow.log_metric("lit_serve_latency", latency)
-
-             if response.status_code == 200:
-                 data = response.json()
-                 mlflow.log_metric("response_tokens", len(data.get("response", "").split()))
-                 return data.get("response", "").strip()
-             else:
-                 mlflow.log_metric("request_errors", 1)
-                 error_info = {
-                     "status_code": response.status_code,
-                     "error": response.text,
-                     "timestamp": datetime.now().isoformat()
-                 }
-                 mlflow.log_dict(error_info, "artifacts/error_log.json")
-                 raise ValueError(f"Request failed: {response.status_code}")
-
-     @property
-     def _identifying_params(self) -> Mapping[str, Any]:
-         return {"endpoint_url": self.endpoint_url}
-
-     @property
-     def _llm_type(self) -> str:
-         return "litserve_llm"
-
- # ----------- 2. Pinecone Connection -----------
- @mlflow.trace
- def init_pinecone():
-     PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
-     pc = Pinecone(api_key=PINECONE_API_KEY)
-     return pc.Index("rag-granite-index")
-
- index = init_pinecone()
-
- # ----------- 3. Embedding Model -----------
- embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
-
- # ----------- 4. Context Retrieval with Tracing -----------
- @mlflow.trace
- def get_retrieved_context(query: str, top_k=3):
-     """Retrieve context from Pinecone with performance tracing"""
-     with mlflow.start_span("embedding_generation"):
-         start_time = time.time()
-         query_embedding = embeddings_model.embed_query(query)
-         mlflow.log_metric("embedding_latency", time.time() - start_time)
-
-     with mlflow.start_span("pinecone_query"):
-         start_time = time.time()
-         results = index.query(
-             namespace="rag-ns",
-             vector=query_embedding,
-             top_k=top_k,
-             include_metadata=True
-         )
-         mlflow.log_metric("pinecone_latency", time.time() - start_time)
-         mlflow.log_metric("retrieved_chunks", len(results['matches']))
-
-     context_parts = [match['metadata']['text'] for match in results['matches']]
-     return "\n".join(context_parts)
-
- # ----------- 5. LLM Chain Setup -----------
- model = LitServeLLM(
-     endpoint_url="https://8001-01k2h9d9mervcmgfn66ybkpwvq.cloudspaces.litng.ai/predict"
- )
-
- prompt = PromptTemplate(
-     input_variables=["context", "question"],
-     template="""
- You are a smart assistant. Based on the provided context, answer the question in 1–2 lines only.
- If the context has more details, summarize it concisely.
- Context:
- {context}
- Question: {question}
- Answer:
- """
- )
-
- llm_chain = LLMChain(llm=model, prompt=prompt)
-
- # ----------- 6. RAG Pipeline with Full Tracing -----------
- @mlflow.trace
- def rag_pipeline(question):
-     """End-to-end RAG pipeline with MLflow tracing"""
-     try:
-         # Start a new nested run for each query
-         with mlflow.start_run(run_name=f"Query-{datetime.now().strftime('%H%M%S')}", nested=True):
-             mlflow.log_param("user_question", question)
-
-             # Retrieve context
-             retrieved_context = get_retrieved_context(question)
-             mlflow.log_text(retrieved_context, "artifacts/retrieved_context.txt")
-
-             # Generate response
-             start_time = time.time()
-             response = llm_chain.invoke({
-                 "context": retrieved_context,
-                 "question": question
-             })["text"].strip()
-
-             # Clean response
-             if "Answer:" in response:
-                 response = response.split("Answer:", 1)[-1].strip()
-
-             # Log metrics
-             mlflow.log_metric("response_latency", time.time() - start_time)
-             mlflow.log_metric("response_length", len(response))
-             mlflow.log_text(response, "artifacts/response.txt")
-
-             return response
-
-     except Exception as e:
-         mlflow.log_metric("pipeline_errors", 1)
-         error_info = {
-             "error": str(e),
-             "question": question,
-             "timestamp": datetime.now().isoformat()
-         }
-         mlflow.log_dict(error_info, "artifacts/pipeline_errors.json")
-         return f"Error: {str(e)}"
-
- # ----------- 7. Gradio UI with Enhanced Tracking -----------
- with gr.Blocks() as demo:
-     gr.Markdown("# 🛠 Maintenance AI Assistant")
-
-     # Track additional UI metrics
-     usage_counter = gr.State(value=0)
-     session_start = gr.State(value=datetime.now().isoformat())
-
-     question_input = gr.Textbox(label="Ask your maintenance question")
-     answer_output = gr.Textbox(label="AI Response")
-     ask_button = gr.Button("Get Answer")
-     feedback = gr.Radio(["Helpful", "Not Helpful"], label="Was this response helpful?")
-
-     def track_usage(question, count, session_start, feedback=None):
-         """Wrapper to track usage metrics with feedback"""
-         count += 1
-
-         # Start tracking context
-         with mlflow.start_run(run_name=f"User-Interaction-{count}", nested=True):
-             mlflow.log_param("question", question)
-             mlflow.log_param("session_start", session_start)
-
-             # Get response
-             response = rag_pipeline(question)
-
-             # Log feedback if provided
-             if feedback:
-                 mlflow.log_param("user_feedback", feedback)
-                 mlflow.log_metric("helpful_responses", 1 if feedback == "Helpful" else 0)
-
-             # Update metrics
-             mlflow.log_metric("total_queries", count)
-
-             return response, count, session_start
-
-     ask_button.click(
-         track_usage,
-         inputs=[question_input, usage_counter, session_start],
-         outputs=[answer_output, usage_counter, session_start]
-     )
-
-     feedback.change(
-         track_usage,
-         inputs=[question_input, usage_counter, session_start, feedback],
-         outputs=[answer_output, usage_counter, session_start]
-     )
-
- if __name__ == "__main__":
-     # Log deployment information
-     with mlflow.start_run(run_name="Deployment-Info"):
-         mlflow.log_params({
-             "app_version": "1.0.0",
-             "deployment_platform": "Lightning AI",
-             "deployment_time": datetime.now().isoformat(),
-             "code_version": os.getenv("GIT_COMMIT", "dev")
-         })
-
-     # Start Gradio app
-     demo.launch()
-
- '''
-
  import torch
  import mauve
  from sacrebleu import corpus_bleu
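
After this commit, app.py keeps only the three imports above, which point toward a text-generation evaluation workflow. As a rough orientation, the sketch below shows one common way torch, mauve, and sacrebleu are combined to score generated answers against references; the example data and variable names are illustrative assumptions, not code from this repo.

# Minimal evaluation sketch built on the imports this commit keeps.
# The hypotheses/references below are toy placeholders.
import torch
import mauve
from sacrebleu import corpus_bleu

# Generated answers (hypotheses) and human-written reference answers.
hypotheses = [
    "Lubricate the bearings every 500 hours.",
    "Replace the air filter monthly.",
]
references = [
    "Bearings should be lubricated every 500 operating hours.",
    "The air filter should be replaced once a month.",
]

# sacrebleu computes corpus-level BLEU; it takes a list of hypotheses
# and a list of reference streams (a single stream here).
bleu = corpus_bleu(hypotheses, [references])
print(f"BLEU: {bleu.score:.2f}")

# MAUVE measures distributional similarity between human text (p_text)
# and model text (q_text); device_id=-1 featurizes on CPU. In practice
# MAUVE needs hundreds of samples per side to give a stable score.
device_id = 0 if torch.cuda.is_available() else -1
out = mauve.compute_mauve(
    p_text=references,
    q_text=hypotheses,
    device_id=device_id,
    verbose=False,
)
print(f"MAUVE: {out.mauve:.3f}")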