|
|
from flask import Flask, render_template, request, jsonify |
|
|
from qdrant_client import QdrantClient |
|
|
from qdrant_client import models |
|
|
from qdrant_client.models import Batch, PointStruct |
|
|
from pickle import load, dump |
|
|
import numpy as np |
|
|
import os, time, sys |
|
|
from datetime import datetime as dt |
|
|
from datetime import timedelta |
|
|
from datetime import timezone |
|
|
import io |
|
|
import requests |
|
|
import torch.nn.functional as F |
|
|
import torch |
|
|
from torch import Tensor |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
# Flask application; the route handlers below register on this object.
app = Flask(__name__)

# Qdrant connection settings are read from the environment; each is None when
# the corresponding variable is unset.
qdrant_url = os.getenv("qdrant_url")
qdrant_api_key = os.getenv("qdrant_api_key")

# Shared Qdrant client using the HTTP/REST transport (gRPC disabled) on port 443.
client = QdrantClient(
    url=qdrant_url,
    port=443,
    api_key=qdrant_api_key,
    prefer_grpc=False,
)

# Embedding inference runs on the GPU when torch can see one, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token states over the positions the attention mask keeps.

    Positions whose mask value is 0 are zeroed before summing over dim 1, and
    the sum is divided by the count of unmasked positions in each row.
    Assumes ``last_hidden_states`` is (batch, seq, hidden) and
    ``attention_mask`` is (batch, seq) -- TODO confirm against callers.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(keep.logical_not(), 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts
|
|
|
|
|
# e5-base-v2 tokenizer and encoder, loaded once at import time and shared by
# every request; the model weights are moved onto ``device``.
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
model = AutoModel.from_pretrained('intfloat/e5-base-v2')
model = model.to(device)
|
|
|
|
|
def e5embed(query):
    """Embed *query* with e5-base-v2 and return the vector as a list of floats.

    The text is truncated to 512 tokens, mean-pooled over non-padding tokens
    and L2-normalised. The result is flattened, so callers should pass a
    single string: a list of N strings would yield N vectors concatenated.

    NOTE(review): the e5 model card recommends prefixing inputs with
    "query: " / "passage: "; this code does not, and changing that would
    require re-embedding the points already stored in Qdrant.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Pure inference: skip autograd bookkeeping so each request does not build
    # (and hold memory for) a gradient graph it never uses.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy().flatten().tolist()
|
|
|
|
|
def get_id(collection, page_size=10000):
    """Return the next free integer point id for *collection* (largest id + 1).

    Pages through the whole collection with ``client.scroll``, following
    ``next_page_offset``, so collections larger than one page no longer
    produce an id that collides with an existing point (which the caller's
    upsert would silently overwrite). An empty collection yields 0 instead of
    raising ``ValueError`` from ``max([])``. Payloads and vectors are not
    fetched because only ids are needed.

    NOTE(review): this is read-then-write and not atomic; two concurrent
    add requests can still receive the same id.
    """
    max_id = -1
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection,
            limit=page_size,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        for point in points:
            max_id = max(max_id, int(point.id))
        if offset is None:  # no further pages
            return max_id + 1
|
|
|
|
|
@app.route("/")
def index():
    """Render the ``index.html`` template for the site root."""
    page = "index.html"
    return render_template(page)
|
|
|
|
|
# Query prefixes that redirect a search to the 'tils' collection, each paired
# with the named vector to match against. Checked in order.
_TIL_QUERY_PREFIXES = (('tilc:', 'context'), ('til:', 'title'))

# Date reported for hits whose payload carries no 'date' field.
_FALLBACK_DATE = '20200101'


def _route_query(query):
    """Split *query* into (collection_name, vector_name, text_to_embed).

    Only a leading prefix is removed; a later 'til:' inside the text is kept.
    vector_name is None for the 'jks' collection (single unnamed vector).
    """
    stripped = query.strip()
    for prefix, vector_name in _TIL_QUERY_PREFIXES:
        if stripped.startswith(prefix):
            return 'tils', vector_name, stripped[len(prefix):].strip()
    return 'jks', None, query


def _search_jks(vector, limit):
    """POST a vector search for 'jks' to the Qdrant REST API; return the hit dicts."""
    headers = {'Content-Type': 'application/json'}
    if qdrant_api_key:
        # This request bypasses QdrantClient, so it must authenticate itself.
        headers['api-key'] = qdrant_api_key
    data = {"vector": vector, "with_payload": True, "limit": limit}
    response = requests.post(qdrant_url + '/jks/points/search', json=data, headers=headers, timeout=30)
    response.raise_for_status()
    body = response.json()
    # Qdrant wraps hits as {"result": [...], "status": ..., "time": ...}; a bare
    # list is also accepted in case qdrant_url fronts a proxy that unwraps it.
    return body['result'] if isinstance(body, dict) else body


def _format_jks_hit(hit):
    """Convert one raw REST hit (dict) from 'jks' into the JSON shape the UI gets."""
    payload = hit['payload']
    return {"text": payload['text'], "date": str(int(payload.get('date', _FALLBACK_DATE))), "id": hit['id']}


def _format_til_hit(hit):
    """Convert one client ScoredPoint from 'tils' into the JSON shape the UI gets."""
    payload = hit.payload
    text = payload['title']
    context = payload.get('context')
    if context:
        # NOTE(review): title/context are user-submitted (see /add_item) and are
        # joined with raw HTML here; if the frontend renders this via innerHTML
        # this is a stored-XSS vector -- consider escaping.
        text += '<br>Context: ' + context
    return {"text": text, "url": payload['url'], "date": payload.get('date', _FALLBACK_DATE), "id": hit.id}


@app.route("/search", methods=["POST"])
def search():
    """Embed the submitted query and return the matching items as JSON.

    The collection is chosen from a prefix on the query text (the form's
    "collection" field is not used):
      * ``tilc:<text>`` -> 'tils', matched against the "context" vector
      * ``til:<text>``  -> 'tils', matched against the "title" vector
      * anything else   -> 'jks'
    'jks' returns up to ``topN`` hits; 'tils' returns up to 100.
    """
    query = request.form["query"]
    topN = 200

    print('QUERY: ', query)
    collection_name, qvector, query = _route_query(query)

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - timh)

    timh = time.time()
    if collection_name == 'jks':
        results = _search_jks(sq, topN)
        formatter = _format_jks_hit
    else:
        results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
        formatter = _format_til_hit
    print('SEARCH TIME: ', time.time() - timh)

    return jsonify([formatter(r) for r in results])
|
|
|
|
|
|
|
|
|
|
|
def _strip_til_prefix(title):
    """Remove one leading "TIL that" / "TIL:" / "TIL " marker from *title*.

    Only a prefix is removed; "TIL" elsewhere in the text is kept, and the
    "that" variant requires a whole word (so "TIL thatched roofs..." becomes
    "thatched roofs...", not "ched roofs..."). Case-sensitive.
    """
    import re
    return re.sub(r'^TIL(?:\s+that\b|:|\s)\s*', '', title.strip(), count=1).strip()


def _strip_url_scheme(url):
    """Drop a leading "https://" or "http://" from *url*, keeping any later ones.

    A URL that embeds another URL (e.g. an archive link) is left otherwise intact.
    """
    url = url.strip()
    for scheme in ('https://', 'http://'):
        if url.startswith(scheme):
            return url[len(scheme):]
    return url


@app.route("/add_item", methods=["POST"])
def add_item():
    """Embed and store a submitted item, then report which collection got it.

    An empty "url" form field stores the title as plain text in 'jks'.
    Otherwise the item goes to 'tils' with a "TIL ..." prefix stripped from
    the title, the scheme stripped from the URL, and a title vector.
    Ids come from ``get_id`` (largest existing id + 1).
    """
    title = request.form["title"]
    url = request.form["url"]

    if not url.strip():
        collection_name = 'jks'
        cid = get_id(collection_name)
        today = time.strftime("%Y%m%d")
        print('cid', cid, today)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': today}], vectors=[e5embed(title)]),
        )
    else:
        collection_name = 'tils'
        cid = get_id(collection_name)
        print('cid', cid, time.strftime("%Y%m%d"), collection_name)
        til = {
            'title': _strip_til_prefix(title),
            'url': _strip_url_scheme(url),
            "date": time.strftime("%Y%m%d_%H%M"),
        }
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})],
        )

    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})
|
|
|
|
|
|
|
|
# Collections this app writes to; deletes are restricted to these so a client
# cannot remove points from arbitrary collections on the Qdrant instance.
_DELETABLE_COLLECTIONS = frozenset({'jks', 'tils'})


@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete one point, by integer id, from the 'jks' or 'tils' collection.

    Returns ``{"deleted": True}`` on success. Responds 400 with
    ``{"deleted": False, "error": ...}`` for an unknown collection or a
    non-integer id, instead of passing them through to Qdrant.
    """
    joke_id = request.form["id"]
    collection_name = request.form["collection"]
    print('Deleting no.', joke_id, 'from collection', collection_name)

    if collection_name not in _DELETABLE_COLLECTIONS:
        return jsonify({"deleted": False, "error": "unknown collection"}), 400
    try:
        point_id = int(joke_id)
    except ValueError:
        return jsonify({"deleted": False, "error": "invalid id"}), 400

    client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[point_id]))
    return jsonify({"deleted": True})
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Debug mode turns on the Werkzeug interactive debugger (code execution from
    # the browser) and the auto-reloader, which re-imports this module in a
    # child process and so loads the e5 model a second time. Since the server
    # binds 0.0.0.0, keep debug off unless explicitly requested via FLASK_DEBUG.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", debug=debug_mode, port=7860)