Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import torch.nn.functional as F | |
# Sample FAQ knowledge base, kept as (question, answer) pairs for easy editing.
# In production this could be loaded from a CSV file or a database instead.
faq_data = [
    {"question": q, "answer": a}
    for q, a in [
        ("What is your return policy?", "You can return items within 30 days of purchase."),
        ("How can I track my order?", "Use the tracking link sent to your email after shipping."),
        ("Do you offer international shipping?", "Yes, we ship to over 50 countries worldwide."),
        ("How do I reset my password?", "Click on 'Forgot Password' on the login page."),
    ]
]
# Load model and tokenizer used to produce token-level embeddings.
# NOTE(review): GPT-2 is a causal LM, not a model trained for sentence
# similarity — a dedicated embedding model would likely rank FAQs better.
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so batched padding below works.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModel.from_pretrained(model_name)
def mean_pooling(model_output, attention_mask):
    """Average the last hidden state over real (non-padding) tokens.

    model_output: model output whose first element is the last hidden state,
        shape (batch, seq_len, hidden).
    attention_mask: (batch, seq_len) mask — 1 for real tokens, 0 for padding.
    Returns a (batch, hidden) sentence embedding tensor.
    """
    hidden = model_output[0]
    # Broadcast the mask across the hidden dimension so padded positions zero out.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size())
    summed = (hidden * mask).sum(dim=1)
    # Clamp the token count to avoid division by zero for all-padding rows.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def embed_texts(texts):
    """Encode a list of strings into fixed-size sentence embeddings.

    texts: list[str] to embed as one padded batch.
    Returns a (len(texts), hidden) tensor of mean-pooled embeddings.
    """
    # Tokenize as a padded batch; truncation guards against over-long inputs.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Inference only — gradients are never needed here.
    with torch.no_grad():
        outputs = model(**batch)
    return mean_pooling(outputs, batch['attention_mask'])
# Precompute embeddings for every FAQ question once at startup,
# so each incoming query only needs to embed itself.
faq_questions = [item["question"] for item in faq_data]
faq_embeddings = embed_texts(faq_questions)
def search_faq(query):
    """Return the FAQ entry whose question best matches *query*.

    query: free-text user question.
    Returns a Markdown string showing the matched question and its answer.
    """
    q_emb = embed_texts([query])
    # Cosine similarity of the query against every precomputed FAQ embedding.
    scores = F.cosine_similarity(q_emb, faq_embeddings)
    hit = faq_data[int(torch.argmax(scores))]
    return f"**Q:** {hit['question']}\n\n**A:** {hit['answer']}"
# Gradio UI: one textbox in, a Markdown-rendered answer out.
iface = gr.Interface(
    fn=search_faq,
    inputs=gr.Textbox(label="Ask a question", placeholder="e.g., How do I reset my password?"),
    outputs=gr.Markdown(label="Most Relevant Answer"),
    title="🔍 FAQ Semantic Search",
    # Fixed: the description previously claimed a "custom Mistral model",
    # but this file loads GPT-2 (see model_name above).
    description="Ask a question to find the most relevant FAQ using semantic similarity with GPT-2 embeddings.",
)

# Guard the launch so importing this module (e.g. for tests) does not start a server;
# Spaces runs the file as __main__, so deployment behavior is unchanged.
if __name__ == "__main__":
    iface.launch()