Shauryaaa05 commited on
Commit
5e4f3a8
·
verified ·
1 Parent(s): 517f7f1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +78 -32
src/streamlit_app.py CHANGED
@@ -1,40 +1,86 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
 
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
 
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
 
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import faiss
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
8
+ import pickle
9
 
10
+ st.set_page_config(page_title="AutoResolve Agent", page_icon="🤖", layout="centered")
 
 
 
 
 
11
 
12
+ st.title("🤖 AutoResolve: IT Support Agent")
13
+ st.markdown("This end-to-end LLM Agent classifies your IT issue, retrieves the relevant enterprise policy, and generates a solution.")
14
 
15
+ # --- 1. Load Models (Cached so they only load once) ---
16
+ @st.cache_resource
17
+ def load_pipeline():
18
+ # 1. Load DistilBERT Classifier
19
+ # Note: You must upload your 'autoresolve_distilbert_final' folder to the HF space!
20
+ distil_tokenizer = DistilBertTokenizerFast.from_pretrained("./autoresolve_distilbert_final")
21
+ distil_model = DistilBertForSequenceClassification.from_pretrained("./autoresolve_distilbert_final")
22
+
23
+ # 2. Knowledge Base & Retriever
24
+ kb = [
25
+ "Refund Policy: Customers are entitled to a full refund within 30 days of purchase. To process, verify the order number and issue the refund to the original payment method.",
26
+ "Order Tracking: To locate an order, query the shipping database using the 10-digit order number. If the status is 'Dispatched', provide the user with the carrier tracking link.",
27
+ "Password Recovery: If a user cannot log in, send a secure password reset link to their registered email address. Ensure they check their spam folder.",
28
+ "Payment Issues: If a transfer or payment fails, verify if the credit card is expired or if the anti-fraud system flagged the transaction. Recommend trying a different payment method."
29
+ ]
30
+ embedder = SentenceTransformer('all-MiniLM-L6-v2')
31
+ kb_embeddings = embedder.encode(kb, convert_to_numpy=True)
32
+ index = faiss.IndexFlatL2(kb_embeddings.shape[1])
33
+ index.add(kb_embeddings)
34
+
35
+ # 3. Load Generative LLM (CPU mode)
36
+ llama_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
37
+ llama_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="cpu")
38
+
39
+ return distil_tokenizer, distil_model, kb, embedder, index, llama_tokenizer, llama_model
40
 
41
+ with st.spinner("Loading AI Models... (This takes about 60 seconds on initial boot)"):
42
+ distil_tokenizer, distil_model, knowledge_base, embedder, index, llama_tokenizer, llama_model = load_pipeline()
 
43
 
44
+ # Define the intents manually to avoid needing the full dataset for the LabelEncoder
45
+ INTENTS = ['cancel_order', 'change_order', 'change_shipping_address', 'check_cancellation_fee', 'check_invoice', 'check_payment_methods', 'check_refund_policy', 'complaint', 'contact_customer_service', 'contact_human_agent', 'create_account', 'delete_account', 'delivery_options', 'delivery_period', 'edit_account', 'get_invoice', 'get_refund', 'newsletter_subscription', 'payment_issue', 'place_order', 'recover_password', 'registration_problems', 'review', 'set_up_shipping_address', 'switch_account', 'track_order', 'track_refund']
46
 
47
+ # --- 2. The User Interface ---
48
+ user_query = st.text_input("Describe your IT or Support issue:", placeholder="e.g., am I entitled to a reimbursement?")
 
 
 
 
49
 
50
+ if st.button("Submit Ticket"):
51
+ if user_query:
52
+ with st.spinner("Processing..."):
53
+ # Step A: Intent Classification
54
+ inputs = distil_tokenizer(user_query, return_tensors="pt", truncation=True, padding=True)
55
+ with torch.no_grad():
56
+ logits = distil_model(**inputs).logits
57
+ predicted_class_id = logits.argmax().item()
58
+ predicted_intent = INTENTS[predicted_class_id]
59
+
60
+ st.success(f"**Intent Classified:** `{predicted_intent}`")
61
+
62
+ # Step B: Retrieval
63
+ query_vector = embedder.encode([user_query], convert_to_numpy=True)
64
+ distances, indices = index.search(query_vector, 1)
65
+ retrieved_doc = knowledge_base[indices[0][0]]
66
+
67
+ st.info(f"**Retrieved Knowledge Base Document:** {retrieved_doc}")
68
+
69
+ # Step C: Generation
70
+ prompt = f"""<|im_start|>system
71
+ You are AutoResolve, an IT support agent. Answer the user's query using ONLY the provided IT Document. Be polite, concise, and professional.<|im_end|>
72
+ <|im_start|>user
73
+ User Query: {user_query}
74
+ IT Document: {retrieved_doc}<|im_end|>
75
+ <|im_start|>assistant
76
+ """
77
+ gen_inputs = llama_tokenizer(prompt, return_tensors="pt")
78
+ outputs = llama_model.generate(**gen_inputs, max_new_tokens=150, temperature=0.1, pad_token_id=llama_tokenizer.eos_token_id)
79
+
80
+ full_response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
81
+ final_answer = full_response.split("assistant\n")[-1].strip()
82
+
83
+ st.write("### 💬 AutoResolve Agent Response:")
84
+ st.write(f"> {final_answer}")
85
+ else:
86
+ st.warning("Please enter a query first.")