File size: 5,843 Bytes
6aa994f 0768472 6aa994f 6c7ce7e 47cf8b2 5324eac 6aa994f 6c7ce7e 5f7526f 6c7ce7e 6aa994f 5f7526f 5324eac 6aa994f c476445 6aa994f da50bf2 6aa994f 47cf8b2 65b7e4a b4428d3 47cf8b2 6aa994f 0768472 b6625a4 da50bf2 b6625a4 47cf8b2 da50bf2 b6625a4 c476445 b6625a4 bd9071b da50bf2 6aa994f c476445 6aa994f da50bf2 6aa994f 7e9fae4 6c7ce7e 7e9fae4 e27236f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from flask import Flask, render_template, request, jsonify
from qdrant_client import QdrantClient
from qdrant_client import models
from qdrant_client.models import Batch, PointStruct
from pickle import load, dump
import numpy as np
import os, time, sys
from datetime import datetime as dt
from datetime import timedelta
from datetime import timezone
import io
import requests
import torch.nn.functional as F
import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
app = Flask(__name__)

# Qdrant connection settings are read from the environment (None when unset).
qdrant_api_key = os.getenv("qdrant_api_key")
qdrant_url = os.getenv("qdrant_url")

# REST client (gRPC disabled); port 443 is passed explicitly alongside the URL.
client = QdrantClient(
    url=qdrant_url,
    port=443,
    api_key=qdrant_api_key,
    prefer_grpc=False,
)

# Place the embedding model on a GPU whenever torch can see one.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Average the token states whose attention_mask entry is non-zero.

    Args:
        last_hidden_states: per-token states, expected as (batch, seq, hidden);
            the mask gets a trailing axis for broadcasting and both are summed
            over dim 1.
        attention_mask: expected (batch, seq) mask, non-zero for real tokens.

    Returns:
        The masked mean over the sequence axis, expected (batch, hidden).
    """
    keep = attention_mask.unsqueeze(-1).bool()
    masked_states = last_hidden_states.masked_fill(keep.logical_not(), 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return masked_states.sum(dim=1) / token_counts
# The e5-base-v2 tokenizer and encoder are loaded once, at import time; the
# encoder is moved onto `device`.
_E5_MODEL_NAME = 'intfloat/e5-base-v2'
tokenizer = AutoTokenizer.from_pretrained(_E5_MODEL_NAME)
model = AutoModel.from_pretrained(_E5_MODEL_NAME).to(device)
def e5embed(query):
    """Embed text with the module-level e5-base-v2 model.

    The input is tokenized (truncated to 512 tokens), run through the encoder,
    mean-pooled over non-padding tokens with ``average_pool``, L2-normalised,
    and returned as a flat Python list.

    Args:
        query: text to embed. The callers in this module pass a single string;
            a list of strings would have its embeddings concatenated into one
            flat list by ``flatten()``.

    Returns:
        list[float]: the embedding, used as a Qdrant query or point vector.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: without no_grad, each call records an autograd graph for
    # the full forward pass, which wastes memory and time on every request.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    # NOTE(review): the e5 model card recommends "query: " / "passage: " input
    # prefixes. None are added here, and adding them now would require
    # re-embedding the stored vectors, so confirm before changing.
    return embeddings.cpu().numpy().flatten().tolist()
def get_id(collection, page_size=10000):
    """Return the next free integer point id for ``collection``.

    Pages through every point in the collection (ids only; no payloads or
    vectors) and returns one past the largest integer id found, or 0 when the
    collection holds no integer-id points.

    Args:
        collection: name of the Qdrant collection.
        page_size: number of points requested per scroll call.

    Returns:
        int: an id not used by any integer-id point in the collection.
    """
    max_id = -1
    offset = None
    # Walk every page: a single scroll call returns at most `page_size`
    # points. Missing any page could hand out an id that is already taken,
    # and the caller's upsert would then overwrite that existing point.
    while True:
        records, offset = client.scroll(
            collection_name=collection,
            limit=page_size,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        for record in records:
            # UUID (string) ids cannot be ordered against ints; skip them.
            if isinstance(record.id, int) and record.id > max_id:
                max_id = record.id
        if offset is None:
            return int(max_id + 1)
@app.route("/")
def index():
    """Render the index.html template for the root URL."""
    page = "index.html"
    return render_template(page)
# Leading query prefixes that redirect a search to the 'tils' collection,
# each paired with the named vector it searches.
_TIL_QUERY_PREFIXES = (("tilc:", "context"), ("til:", "title"))
# Date reported for points stored without a 'date' payload field.
_DEFAULT_RESULT_DATE = '20200101'


def _parse_search_query(raw_query):
    """Split a search string into (collection_name, vector_name, query_text).

    A leading 'tilc:' / 'til:' selects the 'tils' collection and its
    'context' / 'title' vector. Only that leading prefix is removed (a
    prefix-like token later in the text is kept). Any other query searches
    'jks' unchanged, with vector_name None.
    """
    stripped = raw_query.strip()
    for prefix, vector_name in _TIL_QUERY_PREFIXES:
        if stripped.startswith(prefix):
            return 'tils', vector_name, stripped[len(prefix):]
    return 'jks', None, raw_query


def _search_jks(vector, limit):
    """POST a raw REST vector search on the 'jks' collection; return its hits."""
    data = {"vector": vector, "with_payload": True, "limit": limit}
    response = requests.post(
        qdrant_url + '/jks/points/search',
        json=data,
        # NOTE(review): the QdrantClient authenticates to this same host with
        # this key, so the raw request sends it too; requests omits the header
        # when the key is None. Confirm the endpoint accepts it.
        headers={'Content-Type': 'application/json', 'api-key': qdrant_api_key},
        # Without a timeout, a stalled server would hang this worker forever.
        timeout=30,
    )
    # Fail on HTTP errors instead of iterating over an error body.
    response.raise_for_status()
    results = response.json()
    # Stock Qdrant wraps hits as {"result": [...], "status": ..., "time": ...};
    # accept that shape as well as a bare list of hits.
    if isinstance(results, dict):
        results = results.get('result', [])
    return results


def _format_jks_hit(hit):
    """Shape one raw 'jks' hit (a JSON dict) into the response record."""
    payload = hit['payload']
    date = payload.get('date', _DEFAULT_RESULT_DATE)
    return {"text": payload['text'], "date": str(int(date)), "id": hit['id']}


def _format_til_hit(hit):
    """Shape one 'tils' search hit into the response record."""
    payload = hit.payload
    text = payload['title']
    context = payload.get('context')
    if context:
        # NOTE(review): '<br>' implies the frontend renders this as HTML, and
        # titles come from the /add_item form. Confirm they are escaped
        # client-side.
        text = text + '<br>Context: ' + context
    return {"text": text, "url": payload['url'],
            "date": payload.get('date', _DEFAULT_RESULT_DATE), "id": hit.id}


@app.route("/search", methods=["POST"])
def search():
    """Embed the submitted query and return the matching items as JSON.

    Form fields:
        query: search text. A leading 'tilc:' or 'til:' searches the 'tils'
            collection (context or title vector); anything else searches 'jks'.

    The collection is chosen only from the query prefix; a 'collection' form
    field is not consulted.

    Returns:
        JSON list of {"text", "date", "id"} records (plus "url" for 'tils').
    """
    topN = 200        # hits requested from 'jks'
    tils_limit = 100  # hits requested from 'tils'
    raw_query = request.form["query"]
    print('QUERY: ', raw_query)
    collection_name, qvector, query = _parse_search_query(raw_query)

    timh = time.perf_counter()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.perf_counter() - timh)

    timh = time.perf_counter()
    if collection_name == 'jks':
        hits = _search_jks(sq, topN)
        formatter = _format_jks_hit
    else:
        hits = client.search(collection_name=collection_name, query_vector=(qvector, sq),
                             with_payload=True, limit=tils_limit)
        formatter = _format_til_hit
    print('SEARCH TIME: ', time.perf_counter() - timh)

    return jsonify([formatter(hit) for hit in hits])
# Leading markers removed from a TIL title before storing it, tried in order.
# The trailing space in 'TIL that ' keeps words such as "thatch" intact.
_TIL_TITLE_PREFIXES = ('TIL that ', 'TIL:', 'TIL ')


def _clean_til_title(title):
    """Drop one leading 'TIL that ' / 'TIL:' / 'TIL ' marker and trim whitespace.

    Only the start of the title is touched, so text such as "UNTIL now"
    later in the title survives.
    """
    text = title.strip()
    for prefix in _TIL_TITLE_PREFIXES:
        if text.startswith(prefix):
            text = text[len(prefix):]
            break
    return text.strip()


def _strip_url_scheme(url):
    """Trim whitespace and drop a leading 'https://' or 'http://' only."""
    url = url.strip()
    for scheme in ('https://', 'http://'):
        if url.startswith(scheme):
            return url[len(scheme):]
    return url


@app.route("/add_item", methods=["POST"])
def add_item():
    """Embed and store a new item: a joke when 'url' is blank, else a TIL.

    Form fields:
        title: item text (the joke, or the TIL title).
        url: source URL; blank or whitespace means the item is a joke.

    Jokes go to 'jks' with payload {'text', 'date' (YYYYMMDD)}. TILs go to
    'tils' with payload {'title', 'url', 'date' (YYYYMMDD_HHMM)} and a 'title'
    named vector.

    Returns:
        JSON {"success": True, "index": <collection name>}.
    """
    title = request.form["title"]
    url = request.form["url"]
    # Take one timestamp so the printed and stored dates always agree.
    now = time.localtime()
    day = time.strftime("%Y%m%d", now)
    if url.strip() == '':
        collection_name = 'jks'
        cid = get_id(collection_name)
        print('cid', cid, day)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': day}], vectors=[e5embed(title)]),
        )
    else:
        collection_name = 'tils'
        cid = get_id(collection_name)
        print('cid', cid, day, collection_name)
        til = {'title': _clean_til_title(title), 'url': _strip_url_scheme(url),
               "date": time.strftime("%Y%m%d_%H%M", now)}
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})],
        )
    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})
# The only collections this app writes to; deletions are limited to them.
_DELETABLE_COLLECTIONS = frozenset({'jks', 'tils'})


@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete one point, by integer id, from the 'jks' or 'tils' collection.

    Form fields:
        id: integer point id.
        collection: 'jks' or 'tils'.

    Returns:
        JSON {"deleted": True} on success, or a 400 JSON error when the
        collection is not one this app manages or the id is not an integer.
    """
    joke_id = request.form["id"]
    collection_name = request.form["collection"]
    # The name comes straight from the client. Unchecked, a request could
    # delete points from any collection this API key can reach.
    if collection_name not in _DELETABLE_COLLECTIONS:
        return jsonify({"deleted": False, "error": "unknown collection"}), 400
    try:
        point_id = int(joke_id)
    except ValueError:
        return jsonify({"deleted": False, "error": "id must be an integer"}), 400
    print('Deleting no.', joke_id, 'from collection', collection_name)
    client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[point_id]))
    return jsonify({"deleted": True})
if __name__ == "__main__":
    # Debug mode is not forced on: with a 0.0.0.0 bind it would expose
    # Werkzeug's interactive debugger to the network. Flask's app.run reads
    # FLASK_DEBUG from the environment, so set FLASK_DEBUG=1 for local work.
    app.run(host="0.0.0.0", port=7860)