File size: 2,357 Bytes
bc1eeb8
 
24b8616
 
 
bc1eeb8
24b8616
bb361c7
 
3e2e2ae
9ca8b14
3e2e2ae
 
 
24b8616
bc1eeb8
24b8616
 
bc1eeb8
 
 
 
 
 
 
 
 
 
 
24b8616
bc1eeb8
de6b7d7
02bf09e
b46b89c
bc1eeb8
02bf09e
 
 
 
 
bc1eeb8
 
 
02bf09e
bc1eeb8
 
 
 
 
 
 
b46b89c
 
bc1eeb8
 
 
02bf09e
bc1eeb8
 
b46b89c
 
 
 
 
bc1eeb8
76b40e0
b039489
 
 
76b40e0
 
24b8616
bc1eeb8
f4de068
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sys
import pandas as pd
import numpy as np
import faiss
import gradio as gr
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

# --- Load and clean data ---
# Read every column as text: HTS8 codes can begin with "0" (e.g. chapters
# 01-09), and letting pandas parse them as numbers would strip those zeros.
df = pd.read_csv("tariff_codes.csv", encoding="latin1", dtype=str, low_memory=False)
df.columns = df.columns.str.strip()
# fillna("") keeps missing cells from turning into the literal text "nan".
descriptions = df["brief_description"].fillna("").astype(str).tolist()
codes = df["hts8"].fillna("").astype(str).str.strip().tolist()

# --- Embedding model ---
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Load or compute embeddings + FAISS index ---
# The cache is keyed only by file name. Validate it against the current data
# and model so search results cannot point at the wrong rows (or past the end
# of `codes`). NOTE: this cheap check cannot detect edits that keep the row
# count unchanged -- delete embeddings.npy / faiss.index after editing the CSV.
index = None
if os.path.exists("embeddings.npy") and os.path.exists("faiss.index"):
    embeddings = np.load("embeddings.npy")
    index = faiss.read_index("faiss.index")
    expected_dim = embedding_model.get_sentence_embedding_dimension()
    if index.ntotal != len(descriptions) or (
        expected_dim is not None and index.d != expected_dim
    ):
        print("Cached FAISS index does not match the data; rebuilding.", file=sys.stderr)
        index = None

if index is None:
    embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
    # Unit-length vectors make inner product (IndexFlatIP) equal cosine similarity.
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    np.save("embeddings.npy", embeddings)
    faiss.write_index(index, "faiss.index")

# --- Inference API client ---
# Chat model used by respond(). The access token comes from the HF_TOKEN
# environment variable and is None when that variable is not set.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.3",
    token=os.environ.get("HF_TOKEN"),
)

def respond(message, history):
    """Stream an answer to *message*, grounded in the closest tariff-code rows.

    Embeds the user message, retrieves the 5 most similar descriptions from
    the FAISS index, places them in the system prompt as context, and streams
    the model's reply from the Inference API.

    Args:
        message: The user's latest chat message.
        history: Prior turns as role/content dicts (Gradio ``type="messages"``).

    Yields:
        The reply accumulated so far, with every newline doubled for display.
    """
    query_embedding = embedding_model.encode([message], convert_to_numpy=True)
    # The index stores L2-normalised vectors, so normalise the query the same way.
    faiss.normalize_L2(query_embedding)
    _, indices = index.search(query_embedding, k=5)

    # FAISS pads the result with -1 when the index holds fewer than k vectors;
    # codes[-1] would silently return the last row, so drop those slots.
    context = "\n".join(
        f"{codes[i]}: {descriptions[i]}" for i in indices[0] if i >= 0
    )

    system_prompt = f"""You are an expert assistant specialized in tariff classification.
Your job is to help users find the most appropriate tariff codes based on their description.
Use only the provided context below to answer.

Context:
{context}
"""

    messages = [{"role": "system", "content": system_prompt}]
    # Forward only role/content: Gradio message dicts may carry extra keys
    # (e.g. metadata) that the chat-completion endpoint does not expect.
    messages += [{"role": turn["role"], "content": turn["content"]} for turn in history]
    messages.append({"role": "user", "content": message})

    full_response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=512,
        stream=True,
        temperature=0.7,
        top_p=0.95,
    ):
        # Some streamed chunks (e.g. a trailing usage chunk) may have no choices.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            full_response += token
            yield full_response.replace("\n", "\n\n")


# --- Chat UI ---
# respond() is a generator, so each string it yields replaces the reply shown
# so far; "messages" history is a list of role/content dicts.
demo = gr.ChatInterface(
    respond,
    type="messages",
    title="Tariff Code Bot",
    description="Ask questions about tariff codes using natural language.",
)

if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)