Victoria31 committed on
Commit
da9b3be
·
verified ·
1 Parent(s): 5ee20de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -6,7 +6,7 @@ import requests
6
  import numpy as np
7
  import torch
8
  from sklearn.neighbors import NearestNeighbors
9
- from transformers import AutoTokenizer, AutoModel
10
 
11
  # --- CONFIGURATION ---
12
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
@@ -47,11 +47,8 @@ def chunk_text(text, max_chunk_length=500):
47
  return chunks
48
 
49
  def embed_texts(texts):
50
- encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
51
- with torch.no_grad():
52
- model_output = model(**encoded)
53
- embeddings = model_output.last_hidden_state.mean(dim=1)
54
- return embeddings.cpu().numpy()
55
 
56
  def save_cache(embeddings, chunks):
57
  np.save(EMBEDDING_CACHE_FILE, embeddings)
@@ -106,7 +103,9 @@ def respond(message, history):
106
  response.raise_for_status()
107
  output = response.json()
108
  generated_text = output[0]["generated_text"]
109
- answer = generated_text.split("Answer:")[-1].strip()
 
 
110
 
111
  except Exception as e:
112
  print("API Error:", e)
@@ -121,8 +120,7 @@ def respond(message, history):
121
  # --- INIT SECTION ---
122
 
123
  # Load tokenizer and model for embeddings
124
- tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
125
- model = AutoModel.from_pretrained(EMBEDDING_MODEL)
126
 
127
  # Try to load cached embeddings and chunks
128
  chunk_embeddings, chunks = load_cache()
 
6
  import numpy as np
7
  import torch
8
  from sklearn.neighbors import NearestNeighbors
9
+ from sentence_transformers import SentenceTransformer
10
 
11
  # --- CONFIGURATION ---
12
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
 
47
  return chunks
48
 
49
  def embed_texts(texts):
50
+ return model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
51
+
 
 
 
52
 
53
  def save_cache(embeddings, chunks):
54
  np.save(EMBEDDING_CACHE_FILE, embeddings)
 
103
  response.raise_for_status()
104
  output = response.json()
105
  generated_text = output[0]["generated_text"]
106
+ match = re.search(r"Answer:(.*)", generated_text, re.DOTALL)
107
+ answer = match.group(1).strip() if match else generated_text.strip()
108
+
109
 
110
  except Exception as e:
111
  print("API Error:", e)
 
120
  # --- INIT SECTION ---
121
 
122
  # Load tokenizer and model for embeddings
123
+ model = SentenceTransformer(EMBEDDING_MODEL)
 
124
 
125
  # Try to load cached embeddings and chunks
126
  chunk_embeddings, chunks = load_cache()