| from flask import ( |
| Flask, |
| jsonify, |
| request, |
| render_template_string, |
| abort, |
| ) |
| from flask_cors import CORS |
| import unicodedata |
| import markdown |
| import time |
| import os |
| import gc |
| import base64 |
| from io import BytesIO |
| from random import randint |
| import hashlib |
| from colorama import Fore, Style, init as colorama_init |
| import chromadb |
| import posthog |
| from chromadb.config import Settings |
| from sentence_transformers import SentenceTransformer |
| from werkzeug.middleware.proxy_fix import ProxyFix |
|
|
# colorama setup for terminal output. Fore/Style are imported at the top of
# the file but are not used in the code visible here.
colorama_init()


# Address and port handed to app.run() at the bottom of the file.
port = 7860
host = "0.0.0.0"


# Model name passed to SentenceTransformer; its .encode method is used as the
# embedding function for every per-chat ChromaDB collection.
embedding_model = 'sentence-transformers/all-mpnet-base-v2'


print("Initializing ChromaDB")


# Replace posthog.capture with a no-op so telemetry events are discarded.
# Kept ahead of chromadb.Client() construction, presumably so events emitted
# during client start-up are suppressed too. Keep this ordering.
posthog.capture = lambda *args, **kwargs: None
# No persist directory is configured in these Settings, so the store is
# presumably in-memory (data lost on restart). NOTE(review): confirm against
# the installed chromadb version's defaults.
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
# Bound encode() method, passed as `embedding_function` when collections are
# created/fetched by the /api/chromadb* endpoints.
chromadb_embed_fn = chromadb_embedder.encode
|
|
| |
# WSGI application that serves the README and the /api/* endpoints.
app = Flask(__name__)

# Enable cross-origin requests using flask-cors defaults (no origin list given).
CORS(app)

# Upper bound on accepted request bodies: 100 MiB.
app.config.update(MAX_CONTENT_LENGTH=100 * 1024 * 1024)

# Rewrite request metadata from X-Forwarded-* headers set by upstream proxies.
# x_for=2 trusts two hops when deriving request.remote_addr, which
# get_real_ip() feeds into the per-client collection key.
app.wsgi_app = ProxyFix(
    app.wsgi_app,
    x_for=2,
    x_proto=1,
    x_host=1,
    x_prefix=1,
)
|
|
def get_real_ip():
    """Return the client address for the request currently being handled.

    ``request.remote_addr`` is read after the ProxyFix middleware wrapped
    around ``app.wsgi_app`` has had a chance to rewrite it from the
    X-Forwarded-For header.
    """
    client_addr = request.remote_addr
    return client_addr
|
|
@app.route("/", methods=["GET"])
def index():
    """Serve ``./README.md`` (relative to the working directory) as HTML.

    The Markdown is converted with the ``tables`` extension and the
    resulting HTML string is returned as-is; Flask sends a ``str`` return
    value as an HTML response.

    The HTML is deliberately *not* passed through ``render_template_string``.
    The README is documentation, not a Jinja template, so any ``{{ ... }}``
    or ``{% ... %}`` sequences in it (e.g. in code samples) would otherwise
    be evaluated by Jinja or raise ``TemplateSyntaxError``.
    """
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    return markdown.markdown(content, extensions=["tables"])
|
|
|
|
@app.route("/api/modules", methods=["GET"])
def get_modules():
    """Report which modules this server exposes; only ``chromadb``."""
    enabled = ["chromadb"]
    payload = {"modules": enabled}
    return jsonify(payload)
|
|
@app.route("/api/chromadb", methods=["POST"])
def chromadb_add_messages():
    """Upsert chat messages into the caller's ChromaDB collection.

    Expected JSON object body::

        {
            "chat_id": str,
            "messages": [
                {"id": ..., "content": ..., "role": ..., "date": ..., "meta": ...},
                ...
            ]
        }

    ``meta`` is optional and defaults to ``""``. The target collection is
    named ``chat-<md5(ip-chat_id)>``, so the same ``chat_id`` sent from
    different client IPs maps to different collections.

    Returns:
        JSON ``{"count": <number of messages upserted>}``.

    Aborts with 400 when the body is not a JSON object, ``chat_id`` is not a
    string, ``messages`` is not a list, or any message is not an object
    containing ``id``, ``content``, ``role`` and ``date``.
    """
    data = request.get_json()
    # get_json() may yield None or a non-object (JSON `null`, a list, ...);
    # reject those up front instead of crashing with a TypeError (HTTP 500).
    if not isinstance(data, dict):
        abort(400, "JSON object body is required")
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "messages" not in data or not isinstance(data["messages"], list):
        abort(400, '"messages" is required')

    messages = data["messages"]
    # Validate every entry before touching ChromaDB so a malformed message is
    # a client error (400) rather than a KeyError/TypeError (500), and no
    # collection is created for a request that is going to be rejected.
    required_keys = ("id", "content", "role", "date")
    for m in messages:
        if not isinstance(m, dict) or any(k not in m for k in required_keys):
            abort(400, 'each message must have "id", "content", "role" and "date"')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [m["content"] for m in messages]
    ids = [m["id"] for m in messages]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in messages
    ]

    # Skip the upsert for an empty batch; the collection is still created.
    if ids:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    return jsonify({"count": len(ids)})
|
|
|
|
@app.route("/api/chromadb/query", methods=["POST"])
def chromadb_query():
    """Return the stored messages most similar to a query text.

    Expected JSON object body::

        {"chat_id": str, "query": str, "n_results": int (optional, default 1)}

    The collection is looked up (and created if absent) under the same
    ``chat-<md5(ip-chat_id)>`` name used by the add endpoint. ``n_results``
    is clamped to the number of stored items. A non-positive value, or an
    empty collection, yields an empty list.

    Returns:
        JSON list of ``{"id", "date", "role", "meta", "content", "distance"}``
        objects, in the order returned by ``collection.query()``.

    Aborts with 400 when the body is not a JSON object or ``chat_id`` /
    ``query`` is missing or not a string.
    """
    data = request.get_json()
    # Guard against JSON `null`/lists so they yield 400 instead of a 500.
    if not isinstance(data, dict):
        abort(400, "JSON object body is required")
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    if "n_results" not in data or not isinstance(data["n_results"], int):
        n_results = 1
    else:
        n_results = data["n_results"]

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Never ask for more results than the collection holds.
    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[data["query"]],
            n_results=n_results,
        )
        # Each result field holds one list per query text; exactly one was sent.
        messages = [
            {
                "id": doc_id,
                "date": metadata["date"],
                "role": metadata["role"],
                "meta": metadata["meta"],
                "content": document,
                "distance": distance,
            }
            for doc_id, document, metadata, distance in zip(
                query_result["ids"][0],
                query_result["documents"][0],
                query_result["metadatas"][0],
                query_result["distances"][0],
            )
        ]

    return jsonify(messages)
| |
@app.route("/api/chromadb/purge", methods=["POST"])
def chromadb_purge():
    """Delete every stored item in the caller's chat collection.

    Expected JSON object body: ``{"chat_id": str}``. The collection is
    resolved (and created if absent) under the same
    ``chat-<md5(ip-chat_id)>`` name the other endpoints use.

    Returns:
        Plain-text ``"Ok"`` with HTTP status 200.

    Aborts with 400 when the body is not a JSON object or ``chat_id`` is
    missing or not a string.
    """
    data = request.get_json()
    # Guard against JSON `null`/lists so they yield 400 instead of a 500.
    if not isinstance(data, dict):
        abort(400, "JSON object body is required")
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Count before deleting: depending on the chromadb version,
    # Collection.delete() returns the deleted ids or None, so calling len()
    # on its result can raise TypeError.
    deleted_count = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", deleted_count)

    return 'Ok', 200
if __name__ == "__main__":
    # Start Flask's development server only when this file is executed
    # directly, so importing the module (e.g. from a WSGI server or a test)
    # does not block on app.run().
    app.run(host=host, port=port)