|
|
import logging |
|
|
from datetime import datetime, timedelta |
|
|
|
|
|
from flask import request, Response, current_app as app |
|
|
|
|
|
from app.config import IGNORED_MODEL_NAMES, IMAGE_MODEL_NAMES, AUTH_TOKEN, HISTORY_MSG_LIMIT |
|
|
from app.config import configure_logging |
|
|
from app.utils import send_chat_message, fetch_channel_id, map_model_name, process_content, get_user_contents, \ |
|
|
generate_hash, get_next_auth_token, handle_error, get_request_parameters |
|
|
|
|
|
configure_logging() |
|
|
storage_map = {} |
|
|
|
|
|
|
|
|
@app.route("/hf/v1/chat/completions", methods=["GET", "POST", "OPTIONS"]) |
|
|
def onRequest(): |
|
|
try: |
|
|
return fetch(request) |
|
|
except Exception as e: |
|
|
logging.error("An error occurred with chat : %s", e) |
|
|
return handle_error(e) |
|
|
|
|
|
|
|
|
@app.route('/hf/v1/models') |
|
|
def list_models(): |
|
|
return { |
|
|
"object": "list", |
|
|
"data": [{ |
|
|
"id": m, |
|
|
"object": "model", |
|
|
"created": int(datetime.now().timestamp()), |
|
|
"owned_by": "popai" |
|
|
} for m in IGNORED_MODEL_NAMES] |
|
|
} |
|
|
|
|
|
|
|
|
@app.route('/hf/v1/images/generations', methods= ["post"]) |
|
|
def image(): |
|
|
try: |
|
|
request.get_json()["model"] = IMAGE_MODEL_NAMES[0] |
|
|
return fetch(request) |
|
|
except Exception as e: |
|
|
logging.error("An error occurred with image : %s", e) |
|
|
return handle_error(e) |
|
|
|
|
|
|
|
|
def get_channel_id(hash_value, token, model_name, content, template_id): |
|
|
if hash_value in storage_map: |
|
|
channel_id, expiry_time = storage_map[hash_value] |
|
|
if expiry_time > datetime.now() and channel_id: |
|
|
logging.info("Returning channel id from cache") |
|
|
return channel_id |
|
|
channel_id = fetch_channel_id(token, model_name, content, template_id) |
|
|
expiry_time = datetime.now() + timedelta(days=1) |
|
|
storage_map[hash_value] = (channel_id, expiry_time) |
|
|
return channel_id |
|
|
|
|
|
|
|
|
def fetch(req): |
|
|
if req.method == "OPTIONS": |
|
|
return handle_options_request() |
|
|
token = req.headers.get("Authorization").replace("Bearer ", "") |
|
|
messages, model_name, prompt, user_stream = get_request_parameters(req.get_json()) |
|
|
model_to_use = map_model_name(model_name) |
|
|
template_id = 2000000 if model_name in IMAGE_MODEL_NAMES else '' |
|
|
|
|
|
if not messages and prompt: |
|
|
final_user_content = prompt |
|
|
first_user_message = final_user_content |
|
|
image_url = None |
|
|
elif messages: |
|
|
last_message = messages[-1] |
|
|
first_user_message, end_user_message, concatenated_messages = get_user_contents(messages, HISTORY_MSG_LIMIT) |
|
|
final_user_content, image_url = process_content(last_message.get('content')) |
|
|
final_user_content = concatenated_messages + '\n' + final_user_content if concatenated_messages else final_user_content |
|
|
|
|
|
|
|
|
hash_value = generate_hash(first_user_message, model_to_use, token) |
|
|
channel_id = get_channel_id(hash_value, token, model_to_use, final_user_content, template_id) |
|
|
|
|
|
if final_user_content is None: |
|
|
return Response("No user message found", status=400) |
|
|
|
|
|
return send_chat_message(req, token, channel_id, final_user_content, model_to_use, user_stream, image_url, model_name) |
|
|
|
|
|
|
|
|
def handle_options_request(): |
|
|
return Response(status=204, headers={'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Headers': '*'}) |
|
|
|