# app.py — "gradio-bgi" Space (commit 9302e28)
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# Sample FAQ data (you can load this from CSV or DB).
# Stored internally as (question, answer) pairs, then shaped into the
# list-of-dicts structure the rest of the app expects.
_QA_PAIRS = [
    ("What is your return policy?",
     "You can return items within 30 days of purchase."),
    ("How can I track my order?",
     "Use the tracking link sent to your email after shipping."),
    ("Do you offer international shipping?",
     "Yes, we ship to over 50 countries worldwide."),
    ("How do I reset my password?",
     "Click on 'Forgot Password' on the login page."),
]
faq_data = [{"question": q, "answer": a} for q, a in _QA_PAIRS]
# Load model and tokenizer.
# NOTE(review): GPT-2 is a causal LM, not a model trained for sentence
# embeddings; mean-pooled hidden states run fine but retrieval quality is
# likely worse than a dedicated embedding model — confirm this choice.
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 defines no pad token; reuse EOS so batched tokenization with
# padding=True does not raise.
tokenizer.pad_token = tokenizer.eos_token
# AutoModel (bare transformer, no LM head) — its output[0] is the last
# hidden state consumed by mean_pooling below.
model = AutoModel.from_pretrained(model_name)
# Mean pooling for sentence embedding
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, weighted by the mask.

    Args:
        model_output: Hugging Face model output; element 0 is the last
            hidden state of shape (batch, seq_len, hidden).
        attention_mask: integer tensor of shape (batch, seq_len), 1 for
            real tokens and 0 for padding.

    Returns:
        Tensor of shape (batch, hidden): the masked mean of the token
        embeddings for each sequence.
    """
    token_embeddings = model_output[0]  # first element: last hidden state
    # Cast the mask to float before arithmetic: relying on implicit int->
    # float promotion in clamp/division is fragile across torch versions.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    # clamp guards against division by zero for an all-padding row
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts
# Generate embeddings
def embed_texts(texts):
    """Encode a list of strings into mean-pooled sentence embeddings.

    Args:
        texts: list of str to embed as one padded batch.

    Returns:
        Tensor of shape (len(texts), hidden) of pooled embeddings.
    """
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**batch)
    return mean_pooling(outputs, batch['attention_mask'])
# Precompute embeddings for every stored FAQ question once at startup,
# so each incoming query only needs a single forward pass.
faq_questions = [entry["question"] for entry in faq_data]
faq_embeddings = embed_texts(faq_questions)
# Search function
def search_faq(query):
    """Find the FAQ entry most similar to *query* and format it as Markdown.

    Embeds the query, scores it against the precomputed FAQ question
    embeddings with cosine similarity, and returns the best match.
    """
    q_vec = embed_texts([query])
    # (1, hidden) vs (n_faq, hidden) broadcasts to a score per FAQ entry.
    scores = F.cosine_similarity(q_vec, faq_embeddings)
    hit = faq_data[int(torch.argmax(scores))]
    return f"**Q:** {hit['question']}\n\n**A:** {hit['answer']}"
# Gradio interface.
# Fixes vs. original: the description claimed "a custom Mistral model"
# while the loaded model is GPT-2; the title emoji was mojibake.
iface = gr.Interface(
    fn=search_faq,
    inputs=gr.Textbox(label="Ask a question", placeholder="e.g., How do I reset my password?"),
    outputs=gr.Markdown(label="Most Relevant Answer"),
    title="🔍 FAQ Semantic Search",
    description="Ask a question to find the most relevant FAQ using semantic similarity with GPT-2 embeddings.",
)
iface.launch()