vickyvigneshmass committed on
Commit
836cff2
·
verified ·
1 Parent(s): a5faf3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -3
app.py CHANGED
@@ -1,7 +1,84 @@
1
- from fastapi import FastAPI
 
 
 
 
2
 
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
- def greet_json():
7
- return {"welcome": "Created!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query, UploadFile, File
2
+ from sentence_transformers import SentenceTransformer, util
3
+ import torch
4
+ import pickle
5
+ import os
6
 
7
# FastAPI instance
app = FastAPI()

# Global variables
MODEL_NAME = 'all-MiniLM-L6-v2'           # sentence-transformers model used for embeddings
EMBEDDING_CACHE = 'embeddings_cache.pkl'  # pickle cache holding (sentences, embeddings)
DOCUMENT_PATH = 'test.txt'                # working document; overwritten by the /upload route

# Embedding model plus in-memory corpus state. `sentences` and
# `sentence_embeddings` are rebound at import time below (cache load)
# and again by the /upload and /load_model routes.
# NOTE: constructing SentenceTransformer here downloads the model on
# first run — a module-import side effect.
model = SentenceTransformer(MODEL_NAME)
sentences = []
sentence_embeddings = None
19
def load_and_encode_document(file_path):
    """Read *file_path*, split it into non-empty stripped lines, and embed
    each line with the module-level SentenceTransformer.

    Returns a ``(sentences, embeddings)`` pair, where ``embeddings`` is the
    tensor produced by ``model.encode(..., convert_to_tensor=True)``.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()

    # Keep only lines that still contain text after stripping whitespace.
    lines = []
    for raw_line in raw_text.split('\n'):
        stripped = raw_line.strip()
        if stripped:
            lines.append(stripped)

    vectors = model.encode(lines, convert_to_tensor=True)
    return lines, vectors
26
+
27
# Load embeddings if cached; otherwise embed the document now and write
# the cache so later startups skip the (slow) encoding step.
# NOTE(review): unpickling executes arbitrary code from the file — fine
# for a locally written cache, but never point EMBEDDING_CACHE at
# untrusted data.
if os.path.exists(EMBEDDING_CACHE):
    with open(EMBEDDING_CACHE, 'rb') as f:
        sentences, sentence_embeddings = pickle.load(f)
else:
    sentences, sentence_embeddings = load_and_encode_document(DOCUMENT_PATH)
    with open(EMBEDDING_CACHE, 'wb') as f:
        pickle.dump((sentences, sentence_embeddings), f)
35
+
36
@app.get("/")
def welcome():
    """Health-check route: confirm the service is up and reachable."""
    payload = {"message": "Document Retrieval Service is Running!"}
    return payload
39
+
40
@app.get("/search")
def search_text(
    text: str = Query(..., description="Enter your query"),
    top_k: int = Query(5, description="Number of top matches to return"),
    threshold: float = Query(0.5, description="Minimum similarity score threshold")
):
    """Semantic search over the embedded document.

    Encodes *text* with the module-level model, scores it against every
    cached sentence embedding via cosine similarity, and returns up to
    *top_k* sentences whose score is at least *threshold* (highest first).
    """
    query_embedding = model.encode(text, convert_to_tensor=True)
    scores = util.cos_sim(query_embedding, sentence_embeddings)[0]

    # BUG FIX: torch.topk raises a RuntimeError when k exceeds the number
    # of scored sentences (e.g. the default top_k=5 against a 3-line
    # document) or when k is negative. Clamp k into [0, len(scores)].
    k = max(0, min(top_k, scores.shape[0]))
    top_results = torch.topk(scores, k=k)

    results = []
    for idx in top_results.indices:
        score = scores[idx].item()
        if score >= threshold:
            results.append({
                "matched_sentence": sentences[idx],
                "similarity_score": round(score, 3)
            })

    return {
        "query": text,
        # Fall back to a human-readable message when nothing clears the bar.
        "top_matches": results or "No relevant matches found above threshold."
    }
63
+
64
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Replace the working document with the uploaded file's UTF-8 text,
    re-embed it, and refresh both the in-memory corpus and the on-disk
    embedding cache."""
    global sentences, sentence_embeddings

    raw_bytes = await file.read()
    decoded = raw_bytes.decode("utf-8")
    with open(DOCUMENT_PATH, "w", encoding="utf-8") as doc_out:
        doc_out.write(decoded)

    # Rebuild the corpus from the freshly written document, then persist it.
    sentences, sentence_embeddings = load_and_encode_document(DOCUMENT_PATH)
    with open(EMBEDDING_CACHE, 'wb') as cache_out:
        pickle.dump((sentences, sentence_embeddings), cache_out)

    return {"message": f"File '{file.filename}' uploaded and processed successfully."}
75
+
76
@app.post("/load_model")
def load_model(model_name: str = Query(..., description="HuggingFace model name to load")):
    """Swap the embedding model at runtime, re-embed the current document,
    and persist the refreshed cache.

    NOTE(review): *model_name* is passed to SentenceTransformer as-is, so
    any Hub model can be fetched — presumably only trusted callers reach
    this route; confirm before exposing it publicly.
    """
    global model, sentences, sentence_embeddings, MODEL_NAME

    MODEL_NAME = model_name
    # `model` must be rebound before re-encoding: load_and_encode_document
    # embeds with the module-level model.
    model = SentenceTransformer(model_name)
    sentences, sentence_embeddings = load_and_encode_document(DOCUMENT_PATH)

    state = (sentences, sentence_embeddings)
    with open(EMBEDDING_CACHE, 'wb') as cache_file:
        pickle.dump(state, cache_file)

    return {"message": f"Model '{model_name}' loaded and document re-embedded successfully."}