tariff_codes / app.py
dxnxk's picture
Update app.py
de6b7d7 verified
import os
import sys
import pandas as pd
import numpy as np
import faiss
import gradio as gr
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
# --- Load and clean data ---
# Header names may carry stray whitespace (they are stripped below), so read
# the header row alone first to learn the exact raw name of the code column.
_raw_columns = pd.read_csv("tariff_codes.csv", encoding="latin1", nrows=0).columns
# Pin the code column to text. Under dtype inference an all-digit column is
# parsed as a number, which drops leading zeros from the codes before they
# are ever converted back to strings.
_code_dtype = {name: str for name in _raw_columns if str(name).strip() == "hts8"}
df = pd.read_csv(
    "tariff_codes.csv",
    encoding="latin1",
    low_memory=False,
    dtype=_code_dtype or None,
)
df.columns = df.columns.str.strip()
# Parallel lists: position i in both refers to the same CSV row, which is how
# FAISS row ids are mapped back to a code and its description.
descriptions = df["brief_description"].astype(str).tolist()
codes = df["hts8"].astype(str).tolist()
# --- Embedding model ---
# One shared sentence encoder, used both to embed the descriptions for the
# index and to embed each incoming user query.
_EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
embedding_model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
# --- Load or compute embeddings + FAISS index ---
embeddings = None
index = None

if os.path.exists("embeddings.npy") and os.path.exists("faiss.index"):
    embeddings = np.load("embeddings.npy")
    index = faiss.read_index("faiss.index")
    # Search results are row positions into `descriptions`/`codes`. A cache
    # built from a different CSV would point at the wrong rows (or past the
    # end of the lists), so discard it when the sizes disagree.
    if index.ntotal != len(descriptions) or len(embeddings) != len(descriptions):
        embeddings = None
        index = None

if index is None:
    embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
    # Unit-length vectors make the inner product below equal cosine similarity.
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    # Persist both artifacts so the next start-up can skip the encoding pass.
    np.save("embeddings.npy", embeddings)
    faiss.write_index(index, "faiss.index")
# --- Inference API client ---
# The access token is read from the HF_TOKEN environment variable; it is None
# when the variable is unset.
_hf_token = os.environ.get("HF_TOKEN")
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token=_hf_token)
def respond(message, history):
    """Stream an answer to *message* grounded in the nearest tariff entries.

    The message is embedded, the five closest descriptions are looked up in
    the FAISS index, and those ``code: description`` lines are placed in the
    system prompt sent to the chat-completion endpoint.

    Args:
        message: The user's latest chat message (text).
        history: Earlier conversation turns; they are forwarded unchanged
            between the system prompt and the new user message, so they must
            already be role/content message dicts.

    Yields:
        The response text accumulated so far, with every newline doubled,
        once for each streamed chunk that carries new text.
    """
    query_embedding = embedding_model.encode([message], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    _, indices = index.search(query_embedding, k=5)

    # FAISS pads the result with -1 when it has fewer than k vectors to return.
    # A negative id would silently index from the END of the lists, so drop it.
    hits = [int(i) for i in indices[0] if i >= 0]
    context = "\n".join(f"{codes[i]}: {descriptions[i]}" for i in hits)

    system_prompt = f"""You are an expert assistant specialized in tariff classification.
Your job is to help users find the most appropriate tariff codes based on their description.
Use only the provided context below to answer.
Context:
{context}
"""
    messages = [
        {"role": "system", "content": system_prompt},
        *history,
        {"role": "user", "content": message},
    ]

    full_response = ""
    stream = client.chat_completion(
        messages,
        max_tokens=512,
        stream=True,
        temperature=0.7,
        top_p=0.95,
    )
    for chunk in stream:
        # Some stream chunks carry no choices at all; indexing [0] would raise.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            full_response += token
            yield full_response.replace("\n", "\n\n")
# Chat UI: `respond` is the streaming handler behind the interface.
_chat_options = {
    "type": "messages",
    "title": "Tariff Code Bot",
    "description": "Ask questions about tariff codes using natural language.",
}
demo = gr.ChatInterface(respond, **_chat_options)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=True)