Update app.py

app.py CHANGED
@@ -1,4 +1,3 @@
-##
 import torch
 import mauve
 from sacrebleu import corpus_bleu
@@ -258,52 +257,158 @@ Answer:
 
 llm_chain = LLMChain(llm=model, prompt=prompt)
 
-# ------------------ RAG
-def
+# ------------------ RAG Pipeline ------------------
+def get_rag_response(question):
+    """Get the complete RAG response without streaming"""
     retrieved_context = get_retrieved_context(question)
     full_response = llm_chain.invoke({
         "context": retrieved_context,
         "question": question
     })["text"].strip()
+
     if "Answer:" in full_response:
         full_response = full_response.split("Answer:", 1)[-1].strip()
 
+    return full_response, retrieved_context
+
+def rag_pipeline_stream(question):
+    """Streaming version of RAG pipeline"""
+    full_response, _ = get_rag_response(question)
+
     # Stream word by word
+    words = full_response.split()
     current_text = ""
-    for word in
+    for word in words:
         current_text += word + " "
         yield current_text
-
-    # Evaluate after streaming
-    evaluator.evaluate_all(question, full_response, retrieved_context)
+        time.sleep(0.05)  # Adjust speed as needed
 
 # ------------------ Gradio UI ------------------
-with gr.Blocks() as demo:
-    gr.Markdown("
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("""
+    # 🔧 Maintenance AI Assistant
+    *Your intelligent companion for maintenance queries and troubleshooting*
+    """)
+
     usage_counter = gr.State(value=0)
     session_start = gr.State(value=datetime.now().isoformat())
+    current_response = gr.State(value="")  # Store current response for evaluation
 
-
-
-
-
-
-
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("### 💬 Chat Interface")
+            question_input = gr.Textbox(
+                label="Ask your maintenance question",
+                placeholder="e.g., How do I troubleshoot a leaking valve?",
+                lines=2
+            )
+            ask_button = gr.Button("Get Answer 🚀", variant="primary")
+
+            feedback = gr.Radio(
+                ["Helpful", "Not Helpful"],
+                label="Was this response helpful?",
+                info="Your feedback helps improve the system"
+            )
+
+            gr.Markdown("### 📊 Evaluation Metrics")
+            metrics_output = gr.JSON(label="Quality Metrics", visible=False)
+
+        with gr.Column(scale=1):
+            gr.Markdown("### 🤖 AI Response")
+            answer_output = gr.Textbox(
+                label="Response",
+                lines=6,
+                interactive=False,
+                show_copy_button=True
+            )
+
+            with gr.Row():
+                clear_btn = gr.Button("Clear Chat 🗑️")
+                evaluate_btn = gr.Button("Show Metrics 📊", variant="secondary")
+
+    def track_usage(question, count, session_start, feedback_value=None):
+        """Track usage and get response"""
         count += 1
         with mlflow.start_run(run_name=f"User-Interaction-{count}", nested=True):
             mlflow.log_param("question", question)
             mlflow.log_param("session_start", session_start)
-
-
-
+            mlflow.log_param("user_feedback", feedback_value or "No feedback")
+
+            if feedback_value:
+                mlflow.log_metric("helpful_responses", 1 if feedback_value == "Helpful" else 0)
+
             mlflow.log_metric("total_queries", count)
-
+
+            # Get response and context
+            response, context = get_rag_response(question)
+
+            # Log response metrics
+            mlflow.log_metric("response_length", len(response))
+            mlflow.log_metric("response_tokens", len(response.split()))
+
+        return response, count, session_start, response
 
+    def evaluate_response(question, response):
+        """Evaluate the response and return metrics"""
+        if not question or not response:
+            return gr.update(value={}, visible=False)
+
+        try:
+            context = get_retrieved_context(question)
+            metrics = evaluator.evaluate_all(question, response, context)
+
+            # Log metrics to MLflow
+            for metric_name, metric_value in metrics.items():
+                if isinstance(metric_value, (int, float)):
+                    mlflow.log_metric(metric_name, metric_value)
+
+            return gr.update(value=metrics, visible=True)
+        except Exception as e:
+            print(f"Evaluation error: {e}")
+            return gr.update(value={"error": str(e)}, visible=True)
+
+    def clear_chat():
+        """Clear the chat interface"""
+        return "", "", gr.update(visible=False)
+
+    # Main interaction flow
     ask_button.click(
-
+        fn=lambda: ("", gr.update(visible=False)),  # Clear previous metrics
+        outputs=[answer_output, metrics_output]
+    ).then(
+        fn=rag_pipeline_stream,
+        inputs=[question_input],
+        outputs=[answer_output]
+    ).then(
+        fn=track_usage,
         inputs=[question_input, usage_counter, session_start, feedback],
-        outputs=[answer_output, usage_counter, session_start]
+        outputs=[answer_output, usage_counter, session_start, current_response]
+    )
+
+    # Evaluation flow
+    evaluate_btn.click(
+        fn=evaluate_response,
+        inputs=[question_input, current_response],
+        outputs=[metrics_output]
+    )
+
+    # Clear chat
+    clear_btn.click(
+        fn=clear_chat,
+        outputs=[question_input, answer_output, metrics_output]
+    )
+
+    # Feedback handling
+    feedback.change(
+        fn=lambda feedback_val: mlflow.log_metric("user_feedback_score", 1 if feedback_val == "Helpful" else 0),
+        inputs=[feedback],
+        outputs=[]
     )
 
 if __name__ == "__main__":
-    demo.launch(
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True,
+        show_error=True
+    )
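
For reference, a minimal sketch of the streaming pattern this commit introduces: a Gradio event handler written as a generator yields a progressively longer string, the output Textbox re-renders on every yield, and a chained .then() step runs only after streaming completes. All names below (fake_answer, log_done, the components) are illustrative, not taken from app.py.

import time
import gradio as gr

def fake_answer(question):
    # Generator handler: each yield replaces the Textbox contents,
    # so yielding a growing prefix produces word-by-word streaming.
    current = ""
    for word in f"You asked: {question}".split():
        current += word + " "
        time.sleep(0.05)  # pacing between yields
        yield current

def log_done(question):
    # Runs only after the generator above is exhausted.
    print(f"Finished streaming an answer for {question!r}")

with gr.Blocks() as demo:
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")
    ask = gr.Button("Ask")
    ask.click(fn=fake_answer, inputs=question_box, outputs=answer_box).then(
        fn=log_done, inputs=question_box
    )

if __name__ == "__main__":
    demo.launch()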
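
Likewise, a minimal sketch of the per-interaction MLflow logging done in track_usage. mlflow.start_run(..., nested=True) expects an active parent run, which app.py presumably opens elsewhere during setup; this sketch creates its own parent so it runs standalone.

import mlflow

with mlflow.start_run(run_name="Session"):  # parent run for the whole session
    for count, question in enumerate(["q1", "q2"], start=1):
        # One nested child run per user interaction, mirroring track_usage
        with mlflow.start_run(run_name=f"User-Interaction-{count}", nested=True):
            mlflow.log_param("question", question)
            mlflow.log_metric("total_queries", count)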