maceter636 commited on
Commit
de1c16d
·
1 Parent(s): 6dd18f5

Create server.py

Browse files
Files changed (1) hide show
  1. server.py +132 -0
server.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import (
2
+ Flask,
3
+ jsonify,
4
+ request,
5
+ render_template_string,
6
+ abort,
7
+ )
8
+ from flask_cors import CORS
9
+ import unicodedata
10
+ import markdown
11
+ import time
12
+ import os
13
+ import gc
14
+ import base64
15
+ from io import BytesIO
16
+ from random import randint
17
+ import hashlib
18
+ from colorama import Fore, Style, init as colorama_init
19
+ import chromadb
20
+ import posthog
21
+ from chromadb.config import Settings
22
+ from sentence_transformers import SentenceTransformer
23
+
24
+ colorama_init()
25
+
26
+ port = 7860
27
+ host = "0.0.0.0"
28
+
29
+ embedding_model = 'sentence-transformers/all-mpnet-base-v2'
30
+
31
+ print("Initializing ChromaDB")
32
+
33
+ # disable chromadb telemetry
34
+ posthog.capture = lambda *args, **kwargs: None
35
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
36
+ chromadb_embedder = SentenceTransformer(embedding_model)
37
+ chromadb_embed_fn = chromadb_embedder.encode
38
+
39
+ # Flask init
40
+ app = Flask(__name__)
41
+ CORS(app) # allow cross-domain requests
42
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
43
+
44
+ def get_real_ip():
45
+ return request.environ.get('HTTP_X_FORWARDED_FOR', request.remote_addr)
46
+
47
+ @app.route("/", methods=["GET"])
48
+ def index():
49
+ with open("./README.md", "r", encoding="utf8") as f:
50
+ content = f.read()
51
+ return render_template_string(markdown.markdown(content, extensions=["tables"]))
52
+
53
+
54
+ @app.route("/api/modules", methods=["GET"])
55
+ def get_modules():
56
+ return jsonify({"modules": ['chromadb']})
57
+
58
+ @app.route("/api/chromadb", methods=["POST"])
59
+ def chromadb_add_messages():
60
+ data = request.get_json()
61
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
62
+ abort(400, '"chat_id" is required')
63
+ if "messages" not in data or not isinstance(data["messages"], list):
64
+ abort(400, '"messages" is required')
65
+
66
+ ip = get_real_ip()
67
+ chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
68
+ collection = chromadb_client.get_or_create_collection(
69
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
70
+ )
71
+
72
+ documents = [m["content"] for m in data["messages"]]
73
+ ids = [m["id"] for m in data["messages"]]
74
+ metadatas = [
75
+ {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
76
+ for m in data["messages"]
77
+ ]
78
+
79
+ collection.upsert(
80
+ ids=ids,
81
+ documents=documents,
82
+ metadatas=metadatas,
83
+ )
84
+
85
+ return jsonify({"count": len(ids)})
86
+
87
+
88
+ @app.route("/api/chromadb/query", methods=["POST"])
89
+ def chromadb_query():
90
+ data = request.get_json()
91
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
92
+ abort(400, '"chat_id" is required')
93
+ if "query" not in data or not isinstance(data["query"], str):
94
+ abort(400, '"query" is required')
95
+
96
+ if "n_results" not in data or not isinstance(data["n_results"], int):
97
+ n_results = 1
98
+ else:
99
+ n_results = data["n_results"]
100
+
101
+ ip = get_real_ip()
102
+ chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
103
+ collection = chromadb_client.get_or_create_collection(
104
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
105
+ )
106
+
107
+ n_results = min(collection.count(), n_results)
108
+ query_result = collection.query(
109
+ query_texts=[data["query"]],
110
+ n_results=n_results,
111
+ )
112
+
113
+ documents = query_result["documents"][0]
114
+ ids = query_result["ids"][0]
115
+ metadatas = query_result["metadatas"][0]
116
+ distances = query_result["distances"][0]
117
+
118
+ messages = [
119
+ {
120
+ "id": ids[i],
121
+ "date": metadatas[i]["date"],
122
+ "role": metadatas[i]["role"],
123
+ "meta": metadatas[i]["meta"],
124
+ "content": documents[i],
125
+ "distance": distances[i],
126
+ }
127
+ for i in range(len(ids))
128
+ ]
129
+
130
+ return jsonify(messages)
131
+
132
+ app.run(host=host, port=port)