# Hugging Face Space app (Space status when this source was captured: Paused)
import os

import numpy as np
import torch
from datasets import load_dataset
from flask import Flask, render_template, request
from transformers import AutoModel, AutoTokenizer
# --- 1. Initialize Flask App ---
app = Flask(__name__)

# --- 2. Load Models and Dataset (Done once on startup) ---
print("Loading models and dataset...")

# Point this to your Hugging Face Dataset repository. The placeholder default
# can be overridden without editing code by setting the DATASET_REPO
# environment variable (e.g. as a Space variable).
DATASET_REPO = os.environ.get(
    "DATASET_REPO", "YourUsername/bible-rag-gemma-with-faiss"
)
MODEL_NAME = "google/embeddinggemma-300m"

# Load the pre-built dataset (its 'train' split).
rag_dataset = load_dataset(DATASET_REPO)['train']

# A FAISS index is held in memory on a Dataset object and is not restored by
# load_dataset, so it must be built here before search() can call
# get_nearest_examples('embeddings', ...). The index name defaults to the
# column name. NOTE(review): assumes 'embeddings' holds fixed-length vectors
# produced by the same model/pooling that search() applies to queries; confirm.
rag_dataset.add_faiss_index(column='embeddings')

# Load the Gemma embedding model and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
embedding_model = AutoModel.from_pretrained(MODEL_NAME)
print("Models and dataset loaded successfully!")
# --- 3. Define App Routes ---
@app.route('/')
def home():
    """Render the search page with no query and no results."""
    return render_template('index.html')
@app.route('/search', methods=['POST'])
def search():
    """Embed the submitted query and render its 10 nearest dataset rows.

    Reads the ``query`` field of the POSTed form. A missing, empty, or
    whitespace-only query renders the page with an empty result list.

    Returns:
        The rendered ``index.html`` with ``results`` (a list of dicts with
        ``score``, ``text``, ``reference`` and ``version`` keys) and ``query``.
    """
    # .get() avoids Flask's 400 Bad Request (KeyError) when the field is absent;
    # strip() keeps whitespace-only input from being embedded as a real query.
    user_query = request.form.get('query', '').strip()
    if not user_query:
        return render_template('index.html', results=[])

    # --- Create embedding for the user's query ---
    inputs = tokenizer(user_query, return_tensors="pt")
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Mean-pool token states over the sequence axis: (1, seq_len, hidden) ->
    # (1, hidden). NOTE(review): must match how the indexed 'embeddings' were
    # computed; confirm.
    query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    # FAISS works on float32; flatten (1, hidden) to a 1-D (hidden,) vector.
    query_embedding = query_embedding.astype(np.float32).reshape(-1)

    # --- Search the FAISS index ---
    scores, retrieved_examples = rag_dataset.get_nearest_examples(
        'embeddings',
        query_embedding,
        k=10,  # Get top 10 results
    )

    # --- Format results for display ---
    results_list = [
        {
            'score': score,
            'text': text,
            'reference': reference,
            'version': version,
        }
        for score, text, reference, version in zip(
            scores,
            retrieved_examples['text'],
            retrieved_examples['reference'],
            retrieved_examples['version'],
        )
    ]
    return render_template('index.html', results=results_list, query=user_query)
# --- 4. Run the App ---
if __name__ == '__main__':
    # Listen on every interface at port 7860 (the default port a Hugging Face
    # Space routes traffic to).
    server_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**server_options)