# Tariff Code Bot — Hugging Face Space (semantic search over HTS tariff codes)
import os
import sys
import pandas as pd
import numpy as np
import faiss
import gradio as gr
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
# --- Load and clean the tariff dataset ---
df = pd.read_csv("tariff_codes.csv", encoding="latin1", low_memory=False)
# Strip stray whitespace from header labels before column access.
df.columns = [c.strip() for c in df.columns]
# Keep the two lookup columns as parallel lists of strings (row i of one
# corresponds to row i of the other).
codes = df["hts8"].astype(str).tolist()
descriptions = df["brief_description"].astype(str).tolist()
# --- Embedding model ---
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Load or compute embeddings + FAISS index (cached on disk) ---
_EMB_PATH = "embeddings.npy"
_IDX_PATH = "faiss.index"
if all(os.path.exists(p) for p in (_EMB_PATH, _IDX_PATH)):
    # Reuse artifacts persisted by a previous run.
    embeddings = np.load(_EMB_PATH)
    index = faiss.read_index(_IDX_PATH)
else:
    # First run: embed every description, L2-normalize so inner product
    # equals cosine similarity, then persist both artifacts for next time.
    embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    np.save(_EMB_PATH, embeddings)
    faiss.write_index(index, _IDX_PATH)

# --- Inference API client ---
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token=os.getenv("HF_TOKEN"))
def respond(message, history):
    """Stream a chat answer grounded in the 5 nearest tariff descriptions.

    Args:
        message: The user's latest query string.
        history: Prior chat turns as a list of {"role", "content"} dicts
            (Gradio ``type="messages"`` format).

    Yields:
        The accumulated assistant response so far, with newlines doubled
        so Gradio's markdown renders paragraph breaks.
    """
    # Embed the query the same way the corpus was embedded: L2-normalized,
    # so inner-product search over IndexFlatIP equals cosine similarity.
    query_embedding = embedding_model.encode([message], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    _, indices = index.search(query_embedding, k=5)
    # FAISS pads the result with -1 when fewer than k vectors exist; skip
    # those so we never index codes[-1]/descriptions[-1] (the last row).
    context = "\n".join(
        f"{codes[i]}: {descriptions[i]}" for i in indices[0] if i >= 0
    )
    system_prompt = f"""You are an expert assistant specialized in tariff classification.
Your job is to help users find the most appropriate tariff codes based on their description.
Use only the provided context below to answer.
Context:
{context}
"""
    messages = [{"role": "system", "content": system_prompt}]
    messages += history + [{"role": "user", "content": message}]
    full_response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=512,
        stream=True,
        temperature=0.7,
        top_p=0.95,
    ):
        token = chunk.choices[0].delta.content
        # delta.content can be None on role/metadata chunks — skip those.
        if token:
            full_response += token
            yield full_response.replace("\n", "\n\n")
# --- Gradio chat UI wired to the streaming responder ---
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    title="Tariff Code Bot",
    description="Ask questions about tariff codes using natural language.",
)

if __name__ == "__main__":
    # share=True additionally publishes a temporary public gradio.live URL.
    demo.launch(share=True)