# Sales-Agent / app.py
# Author: Moaath7x — Hugging Face Space (commit 6d4391a, verified)
import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ==========================================
# 1. Load Dataset
# ==========================================
# Pull the product catalog from the Hugging Face Hub (train split only)
# and materialize it as a pandas DataFrame for the cleaning/indexing steps below.
dataset = load_dataset(
    "electricsheepafrica/nigerian_retail_and_ecommerce_product_catalog_data",
    split="train"
)
df = pd.DataFrame(dataset)
# ==========================================
# 2. Clean Data
# ==========================================
# Normalize the columns the app relies on: fill missing values, coerce to
# stripped strings, and create any column that is absent from the dataset so
# downstream access never raises a KeyError.
for column in ("product_name", "description", "category", "sub_category", "price"):
    if column in df.columns:
        df[column] = df[column].fillna("").astype(str).str.strip()
    else:
        df[column] = ""
# Merging texts with categories: one searchable document per product row.
df["text"] = (
    df["product_name"] + " " + df["description"] + " "
    + df["category"] + " " + df["sub_category"]
)
# Drop exact duplicates, then cap the row count to reduce memory on the Space.
df = df.drop_duplicates(subset=["text"]).head(2000)
# ==========================================
# 3. Create Embeddings + FAISS Index
# ==========================================
# Sentence-embedding model used for semantic retrieval.
embedding_model = SentenceTransformer("all-mpnet-base-v2")  # embeddings
# Encode every product "text" document into a dense numpy vector.
embeddings = embedding_model.encode(df["text"].tolist(), convert_to_numpy=True, show_progress_bar=True)
dimension = embeddings.shape[1]
# HNSW graph index (32 = neighbors per node) for approximate nearest-neighbor search.
index = faiss.IndexHNSWFlat(dimension, 32)
index.add(embeddings)
# ==========================================
# 4. Load LLM
# ==========================================
# microsoft/phi-2: small causal LM used as the response generator.
MODEL_ID = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",        # let accelerate decide weight placement (CPU/GPU)
    torch_dtype=torch.float32
)
# ==========================================
# 5. Retrieval Function
# ==========================================
def retrieve_products(query, top_k=3):
    """Return the `top_k` catalog rows most semantically similar to `query`.

    Encodes the query with the sentence-embedding model, searches the FAISS
    HNSW index, and maps the resulting row indices back onto the DataFrame.
    """
    query_vector = embedding_model.encode([query])
    _, neighbor_ids = index.search(np.array(query_vector), top_k)
    return df.iloc[neighbor_ids[0]]
# ==========================================
# 6. PROMPT TEMPLATE
# ==========================================
# Instruction prompt for the LLM. `{context}` is filled with the retrieved
# product snippets and `{question}` with the raw customer message at call time.
PROMPT_TEMPLATE = """You are "Jawad", a professional and friendly sales assistant for our online store.
Answer in a natural, human-like, conversational style.
Use the provided context to answer the customer's question accurately.
Guidelines:
1. Be polite and professional.
2. If you don't find the answer in the context, say: "I'm sorry, I couldn't find information about that. Could you provide more details?"
3. Do not make up any information.
4. If the customer asks in Arabic, answer in Arabic; if the customer asks in English, answer in English.
Context: {context}
Question: {question}
Jawad's Response:"""
# ==========================================
# 7. Chat Function
# ==========================================
def sales_chat(message, history):
    """Answer a customer message with retrieval-augmented generation.

    Args:
        message: The raw user message (str).
        history: Prior chat turns supplied by gr.ChatInterface (unused).

    Returns:
        The model's generated reply as a plain string.
    """
    retrieved = retrieve_products(message)
    # Build the grounding context from the retrieved catalog rows.
    # "".join avoids quadratic string concatenation in the loop.
    context_parts = []
    for _, row in retrieved.iterrows():
        product_name = row.get('product_name', '')
        description = row.get('description', '')
        price = row.get('price', '')
        context_parts.append(f"Product: {product_name}\nDescription: {description}\nPrice: {price}\n")
    context = "".join(context_parts)
    prompt = PROMPT_TEMPLATE.format(context=context, question=message)
    # The model was loaded with device_map="auto", so accelerate chose its
    # placement; model.device is the correct target for the inputs — a
    # hard-coded cuda/cpu check can mismatch and crash generate().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=250,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            top_k=60,
            repetition_penalty=1.1  # Reduces repetition
        )
    # Decode only the newly generated tokens. Slicing by input length is robust,
    # unlike str.replace(prompt, ...), which silently fails when the tokenizer
    # decode does not round-trip the prompt's whitespace exactly.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    return response
# ==========================================
# 8. Gradio Interface
# ==========================================
# Chat UI: gr.ChatInterface calls sales_chat(message, history) for each turn.
demo = gr.ChatInterface(
    fn=sales_chat,
    title="Jawad - Sales Agent",
    description="Your smart sales advisor — delivering fast, relevant product recommendations tailored to your needs."
)
demo.launch()