pradeep4321 committed on
Commit
fccb3d2
·
verified ·
1 Parent(s): 6899cb0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +54 -27
src/streamlit_app.py CHANGED
@@ -11,7 +11,7 @@ from huggingface_hub import InferenceClient
11
  # CONFIG
12
  # ==============================
13
  st.set_page_config(page_title="Company ChatGPT", layout="wide")
14
- st.title("🏢 Company AI Assistant")
15
 
16
  # ==============================
17
  # LOAD MODELS
@@ -19,9 +19,15 @@ st.title("🏢 Company AI Assistant")
19
  @st.cache_resource
20
  def load_models():
21
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
22
  llm = InferenceClient(
23
  model="meta-llama/Meta-Llama-3-8B-Instruct",
24
- token=os.environ.get("HF_TOKEN")
25
  )
26
  return embed_model, llm
27
 
@@ -32,11 +38,20 @@ embed_model, llm = load_models()
32
  # ==============================
33
  @st.cache_data
34
  def load_data():
35
- df = pd.read_csv("src/company_sample.csv")
 
 
 
 
36
  return df
37
 
38
  df = load_data()
39
- documents = df["text"].tolist()
 
 
 
 
 
40
 
41
  # ==============================
42
  # CREATE VECTOR DB
@@ -46,9 +61,9 @@ def create_faiss(docs):
46
  embeddings = embed_model.encode(docs)
47
  index = faiss.IndexFlatL2(embeddings.shape[1])
48
  index.add(np.array(embeddings))
49
- return index, embeddings
50
 
51
- index, doc_embeddings = create_faiss(documents)
52
 
53
  # ==============================
54
  # RETRIEVAL FUNCTION
@@ -56,7 +71,7 @@ index, doc_embeddings = create_faiss(documents)
56
  def retrieve(query, top_k=3):
57
  q_emb = embed_model.encode([query])
58
  D, I = index.search(np.array(q_emb), top_k)
59
- return [documents[i] for i in I[0]]
60
 
61
  # ==============================
62
  # CHAT HISTORY
@@ -64,7 +79,6 @@ def retrieve(query, top_k=3):
64
  if "messages" not in st.session_state:
65
  st.session_state.messages = []
66
 
67
- # Display history
68
  for msg in st.session_state.messages:
69
  st.chat_message(msg["role"]).write(msg["content"])
70
 
@@ -77,29 +91,42 @@ if query:
77
  st.session_state.messages.append({"role": "user", "content": query})
78
  st.chat_message("user").write(query)
79
 
80
- # 🔍 Retrieve relevant docs
81
  context_docs = retrieve(query)
82
- context = "\n".join(context_docs)
83
-
84
- # 🧠 Build prompt
85
- prompt = f"""
86
- You are a company assistant. Answer ONLY based on the context below.
87
-
 
 
 
 
 
 
 
 
 
88
  Context:
89
  {context}
90
 
91
  Question:
92
  {query}
93
-
94
- Answer:
95
  """
96
-
97
- # 🤖 LLM Call
98
- response = llm.text_generation(
99
- prompt,
100
- max_new_tokens=200,
101
- temperature=0.5
102
- )
103
-
104
- st.session_state.messages.append({"role": "assistant", "content": response})
105
- st.chat_message("assistant").write(response)
 
 
 
 
 
 
 
11
  # CONFIG
12
  # ==============================
13
  st.set_page_config(page_title="Company ChatGPT", layout="wide")
14
+ st.title("🏢 Company AI Assistant (RAG Powered)")
15
 
16
  # ==============================
17
  # LOAD MODELS
 
19
  @st.cache_resource
20
  def load_models():
21
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
22
+
23
+ HF_TOKEN = os.environ.get("HF_TOKEN")
24
+ if not HF_TOKEN:
25
+ st.error("❌ Please add HF_TOKEN in Hugging Face Secrets")
26
+ st.stop()
27
+
28
  llm = InferenceClient(
29
  model="meta-llama/Meta-Llama-3-8B-Instruct",
30
+ token=HF_TOKEN
31
  )
32
  return embed_model, llm
33
 
 
38
  # ==============================
39
  @st.cache_data
40
  def load_data():
41
+ path = "src/company_sample.csv"
42
+ if not os.path.exists(path):
43
+ st.error(f"❌ File not found: {path}")
44
+ st.stop()
45
+ df = pd.read_csv(path)
46
  return df
47
 
48
  df = load_data()
49
+
50
+ if "text" not in df.columns:
51
+ st.error("❌ CSV must contain 'text' column")
52
+ st.stop()
53
+
54
+ documents = df["text"].fillna("").tolist()
55
 
56
  # ==============================
57
  # CREATE VECTOR DB
 
61
  embeddings = embed_model.encode(docs)
62
  index = faiss.IndexFlatL2(embeddings.shape[1])
63
  index.add(np.array(embeddings))
64
+ return index
65
 
66
+ index = create_faiss(documents)
67
 
68
  # ==============================
69
  # RETRIEVAL FUNCTION
 
71
  def retrieve(query, top_k=3):
72
  q_emb = embed_model.encode([query])
73
  D, I = index.search(np.array(q_emb), top_k)
74
+ return [documents[i] for i in I[0] if i < len(documents)]
75
 
76
  # ==============================
77
  # CHAT HISTORY
 
79
  if "messages" not in st.session_state:
80
  st.session_state.messages = []
81
 
 
82
  for msg in st.session_state.messages:
83
  st.chat_message(msg["role"]).write(msg["content"])
84
 
 
91
  st.session_state.messages.append({"role": "user", "content": query})
92
  st.chat_message("user").write(query)
93
 
94
+ # 🔍 Retrieve context
95
  context_docs = retrieve(query)
96
+ context = "\n\n".join(context_docs)
97
+
98
+ # ==============================
99
+ # 🤖 LLM CALL (FIXED)
100
+ # ==============================
101
+ try:
102
+ response = llm.chat_completion(
103
+ messages=[
104
+ {
105
+ "role": "system",
106
+ "content": "You are a company assistant. Answer ONLY from given context. If not found, say 'Not available in company data.'"
107
+ },
108
+ {
109
+ "role": "user",
110
+ "content": f"""
111
  Context:
112
  {context}
113
 
114
  Question:
115
  {query}
 
 
116
  """
117
+ }
118
+ ],
119
+ max_tokens=200,
120
+ temperature=0.5
121
+ )
122
+
123
+ answer = response.choices[0].message.content
124
+
125
+ except Exception as e:
126
+ answer = f"❌ Error: {str(e)}"
127
+
128
+ # ==============================
129
+ # DISPLAY RESPONSE
130
+ # ==============================
131
+ st.session_state.messages.append({"role": "assistant", "content": answer})
132
+ st.chat_message("assistant").write(answer)