Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
# --- Load and clean data ---
# latin1 avoids UnicodeDecodeError on non-UTF-8 rows in the tariff CSV;
# low_memory=False reads the file in one pass so dtypes are inferred consistently.
df = pd.read_csv("tariff_codes.csv", encoding="latin1", low_memory=False)
df.columns = df.columns.str.strip()  # header cells may carry stray whitespace
descriptions = df["brief_description"].astype(str).tolist()
codes = df["hts8"].astype(str).tolist()

# --- Embedding model ---
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Load or compute embeddings + FAISS index ---
# Cache both artifacts on disk so restarts skip the (slow) encoding step.
# Rebuild when the cache is missing OR stale: if the CSV row count no longer
# matches the cached matrix, the vectors would be misaligned with
# `codes`/`descriptions` and retrieval would return the wrong rows.
EMB_PATH = "embeddings.npy"
IDX_PATH = "faiss.index"
cache_ok = False
if os.path.exists(EMB_PATH) and os.path.exists(IDX_PATH):
    embeddings = np.load(EMB_PATH)
    if embeddings.shape[0] == len(descriptions):  # staleness guard
        index = faiss.read_index(IDX_PATH)
        cache_ok = True
if not cache_ok:
    embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
    # L2-normalize so inner-product search (IndexFlatIP) == cosine similarity.
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    np.save(EMB_PATH, embeddings)
    faiss.write_index(index, IDX_PATH)

# --- Inference API client ---
# HF_TOKEN must be provided via the environment (e.g. a Space secret);
# os.getenv returns None if unset, in which case calls run unauthenticated.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token=os.getenv("HF_TOKEN"))
def respond(message, history):
    """Stream a retrieval-grounded answer for a tariff-classification question.

    Args:
        message: The user's latest query string.
        history: Prior turns in OpenAI "messages" format (list of dicts with
            "role" and "content"; gradio may attach extra keys such as
            "metadata").

    Yields:
        The progressively growing assistant reply; gradio re-renders the chat
        bubble on each yield, producing a streaming effect.
    """
    # Embed and normalize the query exactly as the corpus was indexed, so the
    # inner-product search is a cosine-similarity search.
    query_embedding = embedding_model.encode([message], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    _, indices = index.search(query_embedding, k=5)

    # FAISS pads the result with -1 when the index holds fewer than k vectors;
    # filter those out so we never index codes[-1]/descriptions[-1] by mistake.
    hits = [i for i in indices[0] if i >= 0]
    context = "\n".join(f"{codes[i]}: {descriptions[i]}" for i in hits)

    system_prompt = f"""You are an expert assistant specialized in tariff classification.
Your job is to help users find the most appropriate tariff codes based on their description.
Use only the provided context below to answer.
Context:
{context}
"""

    # Keep only the fields the chat-completion API understands — gradio can
    # attach extra keys (e.g. "metadata") that some backends reject.
    past_turns = [{"role": m["role"], "content": m["content"]} for m in history]
    messages = [{"role": "system", "content": system_prompt}]
    messages += past_turns + [{"role": "user", "content": message}]

    full_response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=512,
        stream=True,
        temperature=0.7,
        top_p=0.95,
    ):
        token = chunk.choices[0].delta.content
        if token:  # some stream chunks (e.g. role-only deltas) carry no text
            full_response += token
            # Double the newlines so single line breaks survive markdown rendering.
            yield full_response.replace("\n", "\n\n")
# Chat UI. type="messages" makes gradio pass `history` to the callback in
# OpenAI messages format (list of role/content dicts).
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    title="Tariff Code Bot",
    description="Ask questions about tariff codes using natural language.",
)

if __name__ == "__main__":
    # share=True also exposes a temporary public gradio link.
    demo.launch(share=True)