davidepanza commited on
Commit
b32f886
·
verified ·
1 Parent(s): 411f450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -3,7 +3,8 @@ from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
  from pydantic import BaseModel
5
  from fastapi.middleware.cors import CORSMiddleware
6
- from sentence_transformers import SentenceTransformer
 
7
  import lancedb
8
  import os
9
 
@@ -25,7 +26,14 @@ class BookRequest(BaseModel):
25
  query: str
26
  limit: int = 5
27
 
28
- model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
 
29
 
30
  db = lancedb.connect(
31
  uri=os.getenv("LANCEDB_URI", "your-lancedb-uri"),
@@ -53,7 +61,8 @@ def search_books(book_request: BookRequest):
53
  raise HTTPException(status_code=400, detail="Query string is required")
54
 
55
  # Get query embeddings
56
- query_embedding = model.encode(book_request.query)
 
57
 
58
  table = db.open_table(os.getenv("LANCEDB_TABLE", "book_db"))
59
  results = (table.search(query_embedding)
 
3
  from fastapi.responses import FileResponse
4
  from pydantic import BaseModel
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import torch
8
  import lancedb
9
  import os
10
 
 
26
  query: str
27
  limit: int = 5
28
 
29
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
30
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
31
+
32
+ def encode_query(text):
33
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
37
 
38
  db = lancedb.connect(
39
  uri=os.getenv("LANCEDB_URI", "your-lancedb-uri"),
 
61
  raise HTTPException(status_code=400, detail="Query string is required")
62
 
63
  # Get query embeddings
64
+ query_embedding = encode_query(book_request.query)
65
+ #query_embedding = model.encode(book_request.query)
66
 
67
  table = db.open_table(os.getenv("LANCEDB_TABLE", "book_db"))
68
  results = (table.search(query_embedding)