fixed stuff
Browse files
app.py
CHANGED
|
@@ -10,8 +10,10 @@ import numpy as np
|
|
| 10 |
|
| 11 |
embedding_client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")
|
| 12 |
|
| 13 |
-
def
|
| 14 |
-
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
|
|
@@ -22,8 +24,6 @@ def embed(texts):
|
|
| 22 |
with open("gita.txt", "r", encoding="utf-8") as f:
|
| 23 |
raw_text = f.read()
|
| 24 |
|
| 25 |
-
doc_embeddings = np.array(embed(documents))
|
| 26 |
-
|
| 27 |
def chunk_text(text, chunk_size=500, overlap=50):
|
| 28 |
chunks = []
|
| 29 |
start = 0
|
|
@@ -33,6 +33,9 @@ def chunk_text(text, chunk_size=500, overlap=50):
|
|
| 33 |
start += chunk_size - overlap
|
| 34 |
return chunks
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
# Embedding model (small + free)
|
| 37 |
#embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 38 |
#doc_embeddings = #embedder.encode(documents)
|
|
@@ -40,7 +43,7 @@ def chunk_text(text, chunk_size=500, overlap=50):
|
|
| 40 |
#doc_embeddings = embedder.encode(documents)
|
| 41 |
|
| 42 |
def retrieve(query, top_k=4):
|
| 43 |
-
query_embedding =
|
| 44 |
scores = np.dot(doc_embeddings, query_embedding)
|
| 45 |
top_indices = np.argsort(scores)[-top_k:][::-1]
|
| 46 |
results = [documents[i] for i in top_indices]
|
|
|
|
| 10 |
|
| 11 |
embedding_client = InferenceClient(model="sentence-transformers/all-MiniLM-L6-v2")
|
| 12 |
|
| 13 |
+
def embed_texts(texts):
|
| 14 |
+
if isinstance(texts, str):
|
| 15 |
+
texts = [texts]
|
| 16 |
+
return np.array(embedding_client.feature_extraction(texts))
|
| 17 |
|
| 18 |
|
| 19 |
|
|
|
|
| 24 |
with open("gita.txt", "r", encoding="utf-8") as f:
|
| 25 |
raw_text = f.read()
|
| 26 |
|
|
|
|
|
|
|
| 27 |
def chunk_text(text, chunk_size=500, overlap=50):
|
| 28 |
chunks = []
|
| 29 |
start = 0
|
|
|
|
| 33 |
start += chunk_size - overlap
|
| 34 |
return chunks
|
| 35 |
|
| 36 |
+
documents = chunk_text(raw_text)
|
| 37 |
+
doc_embeddings = embed_texts(documents)
|
| 38 |
+
|
| 39 |
# Embedding model (small + free)
|
| 40 |
#embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 41 |
#doc_embeddings = #embedder.encode(documents)
|
|
|
|
| 43 |
#doc_embeddings = embedder.encode(documents)
|
| 44 |
|
| 45 |
def retrieve(query, top_k=4):
|
| 46 |
+
query_embedding = embed_texts(query)[0]
|
| 47 |
scores = np.dot(doc_embeddings, query_embedding)
|
| 48 |
top_indices = np.argsort(scores)[-top_k:][::-1]
|
| 49 |
results = [documents[i] for i in top_indices]
|