# Hugging Face Space app — "Jawad", a RAG-based sales assistant
# (Gradio chat UI + FAISS retrieval + microsoft/phi-2 generation).
import gradio as gr
import pandas as pd
import faiss
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ==========================================
# 1. Load Dataset
# ==========================================
# Download the full training split of the product catalog and materialize
# it as a pandas DataFrame for the cleaning steps below.
dataset = load_dataset(
    "electricsheepafrica/nigerian_retail_and_ecommerce_product_catalog_data",
    split="train"
)
df = pd.DataFrame(dataset)
# ==========================================
# 2. Clean Data
# ==========================================
# Guarantee every column the app reads actually exists, then normalize
# each one to a stripped string (NaN -> "").
for column in ["product_name", "description", "category", "sub_category", "price"]:
    if column in df.columns:
        df[column] = df[column].fillna("").astype(str).str.strip()
    else:
        df[column] = ""

# One searchable text field per product: name, description and category path.
df["text"] = df[["product_name", "description", "category", "sub_category"]].agg(" ".join, axis=1)
df = df.drop_duplicates(subset=["text"])
df = df.head(2000)  # To reduce memory consumption on Space
# ==========================================
# 3. Create Embeddings + FAISS Index
# ==========================================
# Embed every product text with a sentence-transformer, then index the
# vectors in an HNSW graph (32 neighbor links per node) for fast
# approximate nearest-neighbor search.
embedding_model = SentenceTransformer("all-mpnet-base-v2")  # embeddings
corpus_vectors = embedding_model.encode(
    df["text"].tolist(),
    convert_to_numpy=True,
    show_progress_bar=True,
)
vector_dim = corpus_vectors.shape[1]
index = faiss.IndexHNSWFlat(vector_dim, 32)
index.add(corpus_vectors)
# ==========================================
# 4. Load LLM
# ==========================================
# phi-2 is small enough for a Space; float32 avoids half-precision issues
# on CPU, and device_map="auto" lets accelerate place weights on GPU when
# one is present.
MODEL_ID = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float32
)
# ==========================================
# 5. Retrieval Function
# ==========================================
def retrieve_products(query, top_k=3):
    """Return the `top_k` catalog rows most similar to `query`.

    Embeds the query with the same model used to build the index and runs
    an approximate nearest-neighbor search; the FAISS ids are positional,
    matching `df.iloc` row positions.

    Args:
        query: free-text customer message.
        top_k: number of products to retrieve (default 3).

    Returns:
        A DataFrame slice with the `top_k` best-matching products.
    """
    query_emb = embedding_model.encode([query], convert_to_numpy=True)
    # FAISS requires a C-contiguous float32 matrix; encode() output type
    # varies across sentence-transformers versions, so cast explicitly.
    query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)
    distances, neighbor_ids = index.search(query_emb, top_k)
    return df.iloc[neighbor_ids[0]]
# ==========================================
# 6. PROMPT TEMPLATE
# ==========================================
# Persona prompt for the causal LM. {context} receives the retrieved
# product snippets; {question} receives the raw user message.
PROMPT_TEMPLATE = """You are "Jawad", a professional and friendly sales assistant for our online store.
Answer in a natural, human-like, conversational style.
Use the provided context to answer the customer's question accurately.
Guidelines:
1. Be polite and professional.
2. If you don't find the answer in the context, say: "I'm sorry, I couldn't find information about that. Could you provide more details?"
3. Do not make up any information.
4. If the customer asks in Arabic, answer in Arabic; if the customer asks in English, answer in English.
Context: {context}
Question: {question}
Jawad's Response:"""
# ==========================================
# 7. Chat Function
# ==========================================
def sales_chat(message, history):
    """Gradio ChatInterface callback: answer `message` with RAG.

    Retrieves the most relevant products, formats them into the prompt
    template, and generates a reply with the LLM. `history` is supplied
    by Gradio but unused — each turn is answered independently.

    Args:
        message: the customer's latest message.
        history: prior chat turns (ignored).

    Returns:
        The model's generated reply as a stripped string.
    """
    retrieved = retrieve_products(message)

    # Flatten the retrieved rows into a plain-text context block.
    context_parts = []
    for _, row in retrieved.iterrows():
        context_parts.append(
            f"Product: {row.get('product_name', '')}\n"
            f"Description: {row.get('description', '')}\n"
            f"Price: {row.get('price', '')}\n"
        )
    context = "".join(context_parts)

    prompt = PROMPT_TEMPLATE.format(context=context, question=message)

    # Move inputs to wherever device_map actually placed the model,
    # rather than assuming CUDA whenever it happens to be available.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=250,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            top_k=60,
            repetition_penalty=1.1  # Reduces repetition
        )

    # Decode only the newly generated tokens. Slicing off the prompt by
    # token count is robust where `response.replace(prompt, "")` breaks
    # whenever decoding doesn't reproduce the prompt byte-for-byte (and
    # would also strip any later occurrence of the prompt text).
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
# ==========================================
# 8. Gradio Interface
# ==========================================
# NOTE: the variable must stay named `demo` — HF Spaces' gradio runtime
# looks it up by that name.
demo = gr.ChatInterface(
    sales_chat,
    title="Jawad - Sales Agent",
    description=(
        "Your smart sales advisor — delivering fast, relevant product "
        "recommendations tailored to your needs."
    ),
)
demo.launch()