broadfield-dev committed on
Commit
061fe70
·
verified ·
1 Parent(s): 15d7ab6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request
2
+ from datasets import load_dataset
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import numpy as np
6
+
7
# --- 1. Initialize Flask App ---
app = Flask(__name__)

# --- 2. Load Models and Dataset (Done once on startup) ---
print("Loading models and dataset...")
# Point this to your Hugging Face Dataset repository
DATASET_REPO = "YourUsername/bible-rag-gemma-with-faiss"
MODEL_NAME = "google/embeddinggemma-300m"

# Load the 'train' split of the pre-built dataset.
rag_dataset = load_dataset(DATASET_REPO)['train']

# load_dataset() does not attach a search index, but the /search route calls
# get_nearest_examples('embeddings', ...), which needs an index with that name.
# Build it once here; the index name defaults to the column name.
# NOTE(review): assumes the dataset has an 'embeddings' vector column -- confirm
# against the dataset repository.
if 'embeddings' not in rag_dataset.list_indexes():
    rag_dataset.add_faiss_index(column='embeddings')

# Load the Gemma model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
embedding_model = AutoModel.from_pretrained(MODEL_NAME)
print("Models and dataset loaded successfully!")
23
+
24
+ # --- 3. Define App Routes ---
25
+
26
@app.route('/')
def home():
    """Render the ``index.html`` template without any context variables."""
    return render_template('index.html')
29
+
30
@app.route('/search', methods=['POST'])
def search():
    """Handle a POSTed search form and render the nearest dataset rows.

    Reads the ``query`` form field, embeds it by averaging the model's
    ``last_hidden_state`` over ``dim=1``, asks ``rag_dataset`` for the 10
    nearest examples in its ``'embeddings'`` index, and renders
    ``index.html`` with ``results`` (a list of dicts with ``score``,
    ``text``, ``reference`` and ``version``) and the ``query`` text.

    A missing, empty, or whitespace-only query renders the page with an
    empty result list instead of running a search.
    """
    # .get() + strip(): a missing field or a blank query short-circuits to
    # the empty page rather than raising a 400 or embedding pure whitespace.
    user_query = request.form.get('query', '').strip()
    if not user_query:
        return render_template('index.html', results=[])

    # --- Create embedding for the user's query ---
    inputs = tokenizer(user_query, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = embedding_model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    # Cast to float32 before handing the vector to the index search.
    query_embedding = np.float32(query_embedding)

    # --- Search the FAISS index ---
    scores, retrieved_examples = rag_dataset.get_nearest_examples(
        'embeddings',
        query_embedding,
        k=10,  # Get top 10 results
    )

    # --- Format results for display ---
    # retrieved_examples is column-oriented, so walk the columns in lockstep
    # with the scores to build one dict per hit.
    results_list = [
        {
            'score': score,
            'text': text,
            'reference': reference,
            'version': version,
        }
        for score, text, reference, version in zip(
            scores,
            retrieved_examples['text'],
            retrieved_examples['reference'],
            retrieved_examples['version'],
        )
    ]

    return render_template('index.html', results=results_list, query=user_query)
63
+
64
# --- 4. Run the App ---
# Only start the development server when this file is executed directly,
# not when it is imported by another module or a WSGI server.
if __name__ == '__main__':
    # Listen on all interfaces, port 7860.
    app.run(host='0.0.0.0', port=7860)