mingbaer committed on
Commit
a50a5bf
·
verified ·
1 Parent(s): 29f636b

Coded Pull Relevant Info function

Browse files
Files changed (1) hide show
  1. app.py +24 -0
app.py CHANGED
@@ -15,3 +15,27 @@ import numpy as np
15
  with open("essay_writing.txt", "r", encoding="utf-8") as file:
16
  essay_writing = file.read()
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  with open("essay_writing.txt", "r", encoding="utf-8") as file:
16
  essay_writing = file.read()
17
 
18
# Split the essay text into one chunk per line. Blank / whitespace-only lines
# are dropped so that every entry in cleaned_chunks corresponds to exactly one
# row of chunk_embeddings below.
cleaned_text = essay_writing.strip()
chunks = cleaned_text.split("\n")
# The filter must test the stripped chunk itself; the previous version
# referenced an undefined name and raised NameError at import time.
cleaned_chunks = [chunk.strip() for chunk in chunks if chunk.strip()]

# Load the embedding model once at module import; it is reused for every query.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Embed all non-empty chunks up front so each query only has to embed the
# query text. Row i of chunk_embeddings belongs to cleaned_chunks[i].
chunk_embeddings = model.encode(cleaned_chunks, convert_to_tensor=True)
27
+
28
def pull_relevant_info(query, top_k=3):
    """Return the essay chunks most similar to *query*, joined by newlines.

    The query is embedded with the module-level ``model`` and compared to the
    precomputed ``chunk_embeddings`` using cosine similarity (dot product of
    L2-normalized vectors).

    Args:
        query: Text to search for. Assumed to be a single string so that the
            query embedding is a 1-D tensor -- TODO confirm against callers.
        top_k: Maximum number of chunks to return. Defaults to 3, matching the
            previously hard-coded value.

    Returns:
        The best-matching chunks, most similar first, separated by ``"\\n"``.
        An empty string when there are no chunks to search or ``top_k`` is not
        positive.
    """
    # torch.topk raises if k exceeds the number of candidates, so clamp it to
    # the number of available chunks.
    k = min(top_k, len(cleaned_chunks))
    if k <= 0:
        return ""

    query_embedding = model.encode(query, convert_to_tensor=True)
    query_embedding_normalized = query_embedding / query_embedding.norm()

    # Normalize each chunk embedding (one per row) to unit length.
    chunk_embeddings_normalized = chunk_embeddings / chunk_embeddings.norm(dim=1, keepdim=True)

    # Dot product of unit vectors == cosine similarity, one score per chunk.
    similarities = torch.matmul(chunk_embeddings_normalized, query_embedding_normalized)

    top_indices = torch.topk(similarities, k=k).indices.tolist()

    # Index cleaned_chunks, not chunks: the embeddings were computed from
    # cleaned_chunks, so only that list lines up with the similarity indices.
    relevant_info = "\n".join(cleaned_chunks[i] for i in top_indices)

    return relevant_info
41
+