broadfield-dev commited on
Commit
5ef463a
·
verified ·
1 Parent(s): 4dc4b99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -1,37 +1,37 @@
 
 
1
  import sys
2
  import subprocess
3
  from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
4
  import torch
5
- import torch.nn.functional as F # Import the functional module
6
  from transformers import AutoTokenizer, AutoModel
7
  import os
8
  import chromadb
9
  from huggingface_hub import snapshot_download
10
 
11
- # (App setup and load_resources function are unchanged)
12
  app = Flask(__name__)
13
  app.secret_key = os.urandom(24)
14
 
15
  CHROMA_PATH = "chroma_db"
16
  COLLECTION_NAME = "bible_verses"
17
- # *** CHANGE 1: UPDATE THE MODEL NAME ***
18
- MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
19
- # *** CHANGE 2: UPDATE THE DATASET REPO NAME ***
20
- DATASET_REPO = "broadfield-dev/bible-chromadb-mpnet"
21
  STATUS_FILE = "build_status.log"
22
 
23
  chroma_collection = None
24
  tokenizer = None
25
  embedding_model = None
26
 
27
- # Mean Pooling Function - Take attention mask into account for correct averaging
28
  def mean_pooling(model_output, attention_mask):
29
- token_embeddings = model_output[0] #First element of model_output contains all token embeddings
30
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
31
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
32
 
33
  def load_resources():
34
- # (This function is unchanged)
35
  global chroma_collection, tokenizer, embedding_model
36
  if chroma_collection and embedding_model: return True
37
  print("Attempting to load resources...")
@@ -95,13 +95,14 @@ def search():
95
  if not user_query:
96
  return render_template('index.html', results=[])
97
 
98
- # *** CHANGE 3: USE THE CORRECT POOLING STRATEGY FOR SBERT MODELS ***
99
  encoded_input = tokenizer([user_query], padding=True, truncation=True, return_tensors='pt')
100
  with torch.no_grad():
101
  model_output = embedding_model(**encoded_input)
102
 
103
- sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
104
- query_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
 
 
105
 
106
  search_results = chroma_collection.query(
107
  query_embeddings=query_embedding.cpu().tolist(),
 
1
+ # app.py (Updated for a model with pre-normalized embeddings)
2
+
3
  import sys
4
  import subprocess
5
  from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
6
  import torch
7
+ import torch.nn.functional as F
8
  from transformers import AutoTokenizer, AutoModel
9
  import os
10
  import chromadb
11
  from huggingface_hub import snapshot_download
12
 
 
13
  app = Flask(__name__)
14
  app.secret_key = os.urandom(24)
15
 
16
  CHROMA_PATH = "chroma_db"
17
  COLLECTION_NAME = "bible_verses"
18
+ # *** CHANGE 1: USE A MODEL WITH NORMALIZED EMBEDDINGS ***
19
+ MODEL_NAME = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
20
+ # *** CHANGE 2: USE THE NEW REPO FOR THE NEW DATABASE ***
21
+ DATASET_REPO = "broadfield-dev/bible-chromadb-multi-qa-mpnet"
22
  STATUS_FILE = "build_status.log"
23
 
24
  chroma_collection = None
25
  tokenizer = None
26
  embedding_model = None
27
 
28
+ # Mean Pooling Function - Crucial for sentence-transformer models
29
  def mean_pooling(model_output, attention_mask):
30
+ token_embeddings = model_output[0]
31
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
32
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
33
 
34
  def load_resources():
 
35
  global chroma_collection, tokenizer, embedding_model
36
  if chroma_collection and embedding_model: return True
37
  print("Attempting to load resources...")
 
95
  if not user_query:
96
  return render_template('index.html', results=[])
97
 
 
98
  encoded_input = tokenizer([user_query], padding=True, truncation=True, return_tensors='pt')
99
  with torch.no_grad():
100
  model_output = embedding_model(**encoded_input)
101
 
102
+ query_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
103
+
104
+ # *** REMOVED: NO LONGER NEED TO NORMALIZE THE QUERY EMBEDDING ***
105
+ # query_embedding = F.normalize(query_embedding, p=2, dim=1)
106
 
107
  search_results = chroma_collection.query(
108
  query_embeddings=query_embedding.cpu().tolist(),