from flask import Flask, request, abort, Response
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.exceptions import HTTPException
import os, threading, json, waitress, datetime, traceback
from llama_cpp import Llama
from dotenv import load_dotenv

load_dotenv()

import sentry_sdk
from sentry_sdk.integrations.flask import FlaskIntegration
sentry_sdk.init(
    dsn="https://5dcf8a99012c4c86b9b1f0293f6b4c2e@o4505516024004608.ingest.sentry.io/4505541971935232",
    integrations=[
        FlaskIntegration(),
    ],
    # Set traces_sample_rate to 1.0 to capture 100%
    # of transactions for performance monitoring.
    # We recommend adjusting this value in production.
    traces_sample_rate=1.0
)
# Variables
DEBUGMODEENABLED = (os.getenv('debugModeEnabled', 'False') == 'True')
modelName = "vicuna"
llm = None
AlpacaLoaded = False
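# Note: the quantized GGML model is expected at ./resources/<modelName>-ggml-model-q4.bin
# (here: ./resources/vicuna-ggml-model-q4.bin). DEBUGMODEENABLED is read from the .env file
# and selects Flask's development server instead of waitress in the __main__ block below.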
# Chat Functions
def load_alpaca():
    global llm, AlpacaLoaded, modelName
    if not AlpacaLoaded:
        print("Loading Alpaca...")
        try:
            llm = Llama(model_path=f"./resources/{modelName}-ggml-model-q4.bin", use_mmap=False, n_threads=2, verbose=False, n_ctx=2048)  # use_mlock=True
            AlpacaLoaded = True
            print("Done loading Alpaca.")
            return "Done"
        except AttributeError:
            print("Error loading Alpaca. Please make sure you have the model file in the resources folder.")
            return "Error"
    else:
        print("Alpaca already loaded.")
        return "Already Loaded"
def getChatResponse(modelOutput):
    return str(modelOutput["choices"][0]['message']['content'])
def reload_alpaca():
    global llm, AlpacaLoaded, modelName
    if AlpacaLoaded:
        llm = None
        input("Please confirm that the memory is cleared!")
        AlpacaLoaded = False
    load_alpaca()
    return "Done"
# Authentication Functions
def loadHashes():
    global hashesDict
    try:
        with open("resources/hashes.json", "r") as f:
            hashesDict = json.load(f)
    except FileNotFoundError:
        hashesDict = {}

def saveHashes():
    global hashesDict
    with open("resources/hashes.json", "w") as f:
        json.dump(hashesDict, f)

def addHashes(username: str, password: str):
    global hashesDict
    hashesDict[username] = generate_password_hash(password, method='scrypt')
    saveHashes()

def checkCredentials(username: str, password: str):
    global hashesDict
    if username in hashesDict:
        return check_password_hash(hashesDict[username], password)
    else:
        return False
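# Authentication scheme used by the handlers below: requests send a custom
# "Authorization: <username>:<password>" header (plain username/password pair, not HTTP
# Basic auth). verifyHeaders() aborts with 401 if the header is missing or malformed and
# with 403 if the credentials do not match an scrypt hash stored in resources/hashes.json.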
def verifyHeaders():
    # Check + obtain the Authorization header
    try:
        user, passw = request.headers['Authorization'].split(":")
    except (KeyError, ValueError):
        abort(401)
    # Check if the Authorization header is valid
    credentialsValid = checkCredentials(user, passw)
    if not credentialsValid:
        abort(403)
    else:
        return user

loadHashes()
#addHashes("test", "test")
# General Functions
def getIsoTime():
    return str(datetime.datetime.now().isoformat())

# Flask App
app = Flask(__name__)
@app.route('/')  # route decorator assumed; the registration was not preserved in the source listing
def main():
    return """<!DOCTYPE HTML>
<html>
<head><meta name='color-scheme' content='dark'></head>
<body><p>Hello, World!</p></body>
</html>"""
@app.route('/chat', methods=['GET', 'POST'])  # route decorator assumed; path not preserved in the source listing
def chat():
    if request.method == 'POST':
        print("Chat Completion Requested.")
        verifyHeaders()
        print("Headers verified")
        messages = request.get_json()
        print("Got Message: " + str(messages))
        if AlpacaLoaded:
            modelOutput = llm.create_chat_completion(messages=messages, max_tokens=1024)
            responseMessage = modelOutput["choices"][0]['message']
            print(f"\n\nResponseMessage: {responseMessage}\n\n")
            return Response(json.dumps(responseMessage, indent=2), content_type='application/json')
        else:
            print("Alpaca not loaded.")
            abort(503, "Alpaca not loaded. Please wait a few seconds and try again.")
    else:
        return "Ready" if AlpacaLoaded else "Not Ready", 200 if AlpacaLoaded else 503
@app.route('/debug-sentry')  # route path assumed, following Sentry's standard Flask example
def trigger_error():
    division_by_zero = 1 / 0
@app.errorhandler(HTTPException)  # handler registration assumed; matches the otherwise unused HTTPException import
def handle_exception(e):
    errorInfo = json.dumps({"error": f"{e.code} - {e.name}", "message": e.description}, indent=2)
    return Response(errorInfo, content_type='application/json'), e.code
@app.errorhandler(Exception)  # handler registration assumed; catches uncaught exceptions and reports them to Sentry
def handle_errors(e):
    print(f"INTERNAL SERVER ERROR 500 @ {request.path}")
    exceptionInfo = f"{type(e).__name__}: {str(e)}"
    errorTraceback = traceback.format_exc()
    print(errorTraceback)
    sentry_sdk.capture_exception(e)
    errorInfo = json.dumps({"error": "500 - Internal Server Error", "message": exceptionInfo}, indent=2)
    return Response(errorInfo, content_type='application/json'), 500
if __name__ == '__main__':
    threading.Thread(target=load_alpaca, daemon=True).start()
    port = int(os.getenv("port", "8080"))
    print("Server successfully started.")
    if DEBUGMODEENABLED:
        app.run(host='0.0.0.0', port=port)
    else:
        waitress.serve(app, host='0.0.0.0', port=port, url_scheme='https')