broadfield-dev committed on
Commit
8b13354
·
verified ·
1 Parent(s): 80d005a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -13
app.py CHANGED
@@ -1,34 +1,105 @@
1
- from flask import Flask, render_template, request
 
 
2
  from datasets import load_dataset
3
  import torch
4
  from transformers import AutoTokenizer, AutoModel
5
  import numpy as np
 
6
 
7
  # --- 1. Initialize Flask App ---
8
  app = Flask(__name__)
 
 
9
 
10
- # --- 2. Load Models and Dataset (Done once on startup) ---
11
- print("Loading models and dataset...")
12
- # Point this to your Hugging Face Dataset repository
13
- DATASET_REPO = "YourUsername/bible-rag-gemma-with-faiss"
14
- MODEL_NAME = "google/embeddinggemma-300m"
15
 
16
- # Load the pre-built dataset and FAISS index
17
- rag_dataset = load_dataset(DATASET_REPO)['train']
 
 
18
 
19
- # Load the Gemma model and tokenizer
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- embedding_model = AutoModel.from_pretrained(MODEL_NAME)
22
- print("Models and dataset loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # --- 3. Define App Routes ---
25
 
26
  @app.route('/')
27
  def home():
 
 
28
  return render_template('index.html')
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  @app.route('/search', methods=['POST'])
31
  def search():
 
 
 
 
 
 
 
 
 
32
  user_query = request.form['query']
33
  if not user_query:
34
  return render_template('index.html', results=[])
@@ -39,7 +110,6 @@ def search():
39
  outputs = embedding_model(**inputs)
40
  query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
41
 
42
- # FAISS expects a flattened numpy array
43
  query_embedding = np.float32(query_embedding)
44
 
45
  # --- Search the FAISS index ---
 
import os
import subprocess
import sys

from flask import Flask, render_template, request, flash, redirect, url_for
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

# --- 1. Initialize Flask App ---
app = Flask(__name__)
# Flask sessions (used by flash()) require a secret key. A random key is
# generated on every startup, so flashed messages do not survive a restart.
app.secret_key = os.urandom(24)

# --- 2. Configuration & Resource Loading ---
print("Starting application...")

# Point this to the Hugging Face Dataset repository you want to create/use.
# This MUST match the DATASET_REPO in build_rag.py
DATASET_REPO = "broadfield-dev/bible-rag-dataset-gemma"
MODEL_NAME = "google/gemma-2b" # Use a consistent model for embedding and searching

# Module-level handles for the dataset and models; populated lazily by
# load_resources() so the app can still start when the dataset is missing.
rag_dataset = None
tokenizer = None
embedding_model = None
28
def load_resources():
    """
    Attempt to load the RAG dataset and the embedding model/tokenizer from
    the Hugging Face Hub, caching them in the module-level globals.

    Returns:
        bool: True when everything is available, False when loading failed
        (e.g. the dataset repository does not exist yet).
    """
    global rag_dataset, tokenizer, embedding_model
    # Identity check, not truthiness: a datasets.Dataset defines __len__, so a
    # successfully loaded but empty dataset would be falsy and trigger a
    # pointless (and slow) reload on every call.
    if rag_dataset is not None:
        return True

    print(f"Attempting to load resources: {DATASET_REPO} and {MODEL_NAME}")
    try:
        # Load the pre-built dataset with the FAISS index
        rag_dataset = load_dataset(DATASET_REPO)['train']

        # Load the Gemma model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        embedding_model = AutoModel.from_pretrained(MODEL_NAME)

        print("Models and dataset loaded successfully!")
        return True
    except Exception as e:
        # Broad catch is deliberate: any Hub/network/auth failure should leave
        # the app running in its "not ready" state rather than crash startup.
        print(f"Could not load RAG dataset from '{DATASET_REPO}'. It may not exist yet.")
        print(f"Error: {e}")
        # Reset globals to ensure a clean state
        rag_dataset = None
        tokenizer = None
        embedding_model = None
        return False

# Try to load resources on startup. The app can still run if this fails.
resources_loaded = load_resources()
59
 
60
  # --- 3. Define App Routes ---
61
 
62
@app.route('/')
def home():
    """Render the landing page, warning the user when the RAG dataset is unavailable."""
    if resources_loaded:
        return render_template('index.html')
    # Dataset failed to load at startup: show a banner pointing at the build button.
    flash(f"Welcome! The required RAG dataset '{DATASET_REPO}' is not loaded. Please use the 'Build RAG Dataset' button to create and upload it.", "warning")
    return render_template('index.html')
67
 
68
@app.route('/build-rag', methods=['POST'])
def build_rag_route():
    """
    Trigger the build_rag.py script as a background process.

    NOTE: This requires a Hugging Face token with 'write' permissions
    to be saved as a secret named HF_TOKEN in the Space settings.

    Returns:
        A redirect to the home page with a flashed status message.
    """
    print("RAG build process requested.")
    try:
        # Use Popen to run the script in the background without blocking the app.
        # Let the child inherit this process's stdout/stderr: the previous
        # version piped output to subprocess.PIPE without ever reading it,
        # which deadlocks the build once the OS pipe buffer fills. Inheriting
        # also makes progress actually visible in the Space logs, as the
        # flash message below promises.
        process = subprocess.Popen([sys.executable, "build_rag.py"])
        print(f"Started build process with PID: {process.pid}")
        flash("RAG build process initiated! This will run in the background and can take several minutes. Please check the Space logs for progress. Once complete, you can start searching.", "info")
    except Exception as e:
        print(f"Failed to start build process: {e}")
        flash(f"An error occurred while trying to start the build process: {e}", "error")

    return redirect(url_for('home'))
91
+
92
  @app.route('/search', methods=['POST'])
93
  def search():
94
+ global resources_loaded
95
+ # If resources weren't loaded, try again in case the build just finished.
96
+ if not resources_loaded:
97
+ print("Resources not loaded. Attempting to reload for search...")
98
+ resources_loaded = load_resources()
99
+ if not resources_loaded:
100
+ flash("The RAG dataset is not ready yet. Please wait for the build process to complete or check the logs for errors.", "error")
101
+ return redirect(url_for('home'))
102
+
103
  user_query = request.form['query']
104
  if not user_query:
105
  return render_template('index.html', results=[])
 
110
  outputs = embedding_model(**inputs)
111
  query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
112
 
 
113
  query_embedding = np.float32(query_embedding)
114
 
115
  # --- Search the FAISS index ---