dxnxk committed on
Commit
bc1eeb8
·
1 Parent(s): 29ebbed
Files changed (1) hide show
  1. app.py +48 -27
app.py CHANGED
@@ -1,8 +1,9 @@
 
 
1
  import pandas as pd
2
  import numpy as np
3
- import gradio as gr
4
  import faiss
5
- import sys
6
  from sentence_transformers import SentenceTransformer
7
  from huggingface_hub import InferenceClient
8
 
@@ -12,40 +13,60 @@ df.columns = df.columns.str.strip()
12
  descriptions = df["brief_description"].astype(str).tolist()
13
  codes = df["hts8"].astype(str).tolist()
14
 
15
- # --- Create embeddings ---
16
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
17
- embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
18
-
19
- # --- FAISS index (cosine similarity) ---
20
- dim = embeddings.shape[1]
21
- faiss.normalize_L2(embeddings)
22
- index = faiss.IndexFlatIP(dim)
23
- index.add(embeddings)
24
 
25
- # --- Inference API ---
26
- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
 
 
 
 
 
 
 
 
 
27
 
28
- # --- RAG pipeline ---
29
- def generate_answer(user_query):
30
- sys.stderr.write("=== generate_answer called ===\n")
31
- sys.stderr.flush()
32
 
33
- query_embedding = embedding_model.encode([user_query], convert_to_numpy=True)
 
 
34
  faiss.normalize_L2(query_embedding)
35
  _, indices = index.search(query_embedding, k=5)
36
 
37
  context = "\n".join([f"{codes[i]}: {descriptions[i]}" for i in indices[0]])
38
- prompt = f"""Here are some tariff code descriptions:\n{context}\n\nQuestion: {user_query}\nAnswer:"""
39
 
40
- sys.stderr.write(f"Prompt sent to model:\n{prompt}\n")
41
- sys.stderr.flush()
 
 
42
 
43
- response = client.text_generation(
44
- prompt,
45
- max_new_tokens=200,
 
 
 
 
 
 
 
 
 
 
 
46
  temperature=0.7,
47
- stop_sequences=["\n\n"]
48
- )
49
- return response.strip()
 
 
 
 
 
50
 
51
- gr.Interface(fn=generate_answer, inputs="text", outputs="text").launch()
 
 
1
+ import os
2
+ import sys
3
  import pandas as pd
4
  import numpy as np
 
5
  import faiss
6
+ import gradio as gr
7
  from sentence_transformers import SentenceTransformer
8
  from huggingface_hub import InferenceClient
9
 
 
13
# Parallel lists over the CSV rows: codes[i] is the HTS-8 tariff code for
# descriptions[i]. Cast to str so retrieval formatting never sees NaN/floats.
# NOTE(review): `df` is loaded above this chunk — assumed to have
# "brief_description" and "hts8" columns; confirm against the CSV header.
descriptions = df["brief_description"].astype(str).tolist()
codes = df["hts8"].astype(str).tolist()

# --- Embedding model ---
# Sentence-transformer used for BOTH the corpus and the incoming queries,
# so vectors live in the same space.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
 
18
 
19
# --- Load or compute embeddings + FAISS index ---
# Both artifacts are cached on disk so app restarts skip the expensive
# re-encoding of every description.
_cache_ok = os.path.exists("embeddings.npy") and os.path.exists("faiss.index")
if _cache_ok:
    embeddings = np.load("embeddings.npy")
    index = faiss.read_index("faiss.index")
    # Guard against a stale cache: if the source CSV changed since the cache
    # was written, the saved vectors no longer line up with
    # `descriptions`/`codes`, silently corrupting retrieval — rebuild instead.
    if embeddings.shape[0] != len(descriptions) or index.ntotal != len(descriptions):
        _cache_ok = False
if not _cache_ok:
    embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
    # Normalize so inner-product search (IndexFlatIP) equals cosine similarity.
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    np.save("embeddings.npy", embeddings)
    faiss.write_index(index, "faiss.index")
30
 
31
# --- Inference API client ---
# Hosted chat model used for answer generation (streamed via chat_completion
# in respond()). Relies on the Hugging Face Inference API being reachable;
# presumably an HF token is provided via the environment — confirm on deploy.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
33
 
34
def respond(message, history: list[dict]):
    """Stream a RAG answer for *message*.

    Retrieves the 5 most similar tariff-code descriptions from the FAISS
    index, injects them into a system prompt, and streams the chat model's
    reply as an incrementally growing assistant message (gradio
    ChatInterface "messages" format).

    Args:
        message: The user's latest query string.
        history: Prior turns as a list of {"role", "content"} dicts.

    Yields:
        The same assistant-message dict, with "content" growing per token.
    """
    # 1. Encode the query and retrieve the top-5 matching rows.
    query_embedding = embedding_model.encode([message], convert_to_numpy=True)
    # Normalize so inner-product search equals cosine similarity.
    faiss.normalize_L2(query_embedding)
    _, indices = index.search(query_embedding, k=5)

    context = "\n".join(f"{codes[i]}: {descriptions[i]}" for i in indices[0])

    # 2. Prepare system prompt with role + retrieved context.
    system_prompt = f"""You are an expert assistant specialized in tariff classification.
Your job is to help users find the most appropriate tariff codes based on their description.
Use only the provided context below to answer.

Context:
{context}
"""

    # 3. Full conversation: system message first, then history, then the query.
    messages = [{"role": "system", "content": system_prompt}]
    messages += history + [{"role": "user", "content": message}]

    response = {"role": "assistant", "content": ""}

    # Distinct loop variable: the original shadowed the `message` parameter
    # with each streamed chunk.
    for chunk in client.chat_completion(
        messages,
        max_tokens=512,
        stream=True,
        temperature=0.7,
        top_p=0.95,
    ):
        token = chunk.choices[0].delta.content
        # Streamed deltas can carry content=None (e.g. role-only or final
        # chunks); appending None to a str would raise TypeError.
        if token:
            response["content"] += token
        yield response
67
+
68
+
69
# Chat UI; type="messages" makes gradio pass/accept OpenAI-style role dicts,
# matching what respond() consumes and yields.
demo = gr.ChatInterface(respond, type="messages")

if __name__ == "__main__":
    demo.launch()