dxnxk committed on
Commit
9ca8b14
·
verified ·
1 Parent(s): 84d7796

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -6,35 +6,31 @@ from sentence_transformers import SentenceTransformer
6
  from huggingface_hub import InferenceClient
7
 
8
  # --- Load data ---
9
- df = pd.read_csv("tariff_codes.csv")
10
- descriptions = df["description"].astype(str).tolist()
11
- codes = df["code"].astype(str).tolist()
12
 
13
  # --- Create embeddings ---
14
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
15
  embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
16
 
17
- # --- FAISS index (cosine similarity = inner product on normalized vectors) ---
18
  dim = embeddings.shape[1]
19
  faiss.normalize_L2(embeddings)
20
  index = faiss.IndexFlatIP(dim)
21
  index.add(embeddings)
22
 
23
- # --- Hugging Face Inference API client ---
24
  client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
25
 
26
- # --- RAG response generation ---
27
  def generate_answer(user_query):
28
  query_embedding = embedding_model.encode([user_query], convert_to_numpy=True)
29
  faiss.normalize_L2(query_embedding)
30
  _, indices = index.search(query_embedding, k=5)
31
 
32
- retrieved_context = "\n".join([f"{codes[i]}: {descriptions[i]}" for i in indices[0]])
33
- prompt = f"""Here are some tariff code descriptions:
34
- {retrieved_context}
35
-
36
- Question: {user_query}
37
- Answer:"""
38
 
39
  response = client.text_generation(
40
  prompt,
@@ -44,7 +40,7 @@ Answer:"""
44
  )
45
  return response.strip()
46
 
47
- # --- Gradio Chat Interface ---
48
  gr.ChatInterface(
49
  fn=generate_answer,
50
  title="Tariff Code RAG Bot (FAISS + Inference API)"
 
6
  from huggingface_hub import InferenceClient
7
 
8
  # --- Load data ---
9
+ df = pd.read_csv("tariff_codes.csv", encoding="latin1", low_memory=False)
10
+ descriptions = df["Description"].astype(str).tolist()
11
+ codes = df["Code"].astype(str).tolist()
12
 
13
  # --- Create embeddings ---
14
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
15
  embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
16
 
17
+ # --- FAISS index (cosine similarity) ---
18
  dim = embeddings.shape[1]
19
  faiss.normalize_L2(embeddings)
20
  index = faiss.IndexFlatIP(dim)
21
  index.add(embeddings)
22
 
23
+ # --- Inference API ---
24
  client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
25
 
26
+ # --- RAG pipeline ---
27
  def generate_answer(user_query):
28
  query_embedding = embedding_model.encode([user_query], convert_to_numpy=True)
29
  faiss.normalize_L2(query_embedding)
30
  _, indices = index.search(query_embedding, k=5)
31
 
32
+ context = "\n".join([f"{codes[i]}: {descriptions[i]}" for i in indices[0]])
33
+ prompt = f"""Here are some tariff code descriptions:\n{context}\n\nQuestion: {user_query}\nAnswer:"""
 
 
 
 
34
 
35
  response = client.text_generation(
36
  prompt,
 
40
  )
41
  return response.strip()
42
 
43
+ # --- Gradio UI ---
44
  gr.ChatInterface(
45
  fn=generate_answer,
46
  title="Tariff Code RAG Bot (FAISS + Inference API)"