File size: 6,450 Bytes
8d1ecb1 531dd91 8d1ecb1 31853b5 8d1ecb1 8e95cf7 8d1ecb1 31853b5 8d1ecb1 31853b5 8d1ecb1 31853b5 531dd91 8d1ecb1 531dd91 8d1ecb1 531dd91 31853b5 531dd91 8d1ecb1 531dd91 8d1ecb1 31853b5 531dd91 8d1ecb1 531dd91 8d1ecb1 31853b5 8d1ecb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import hmac
import json
import logging
import os

import requests
from flask import Flask, request, Response, stream_with_context, jsonify
app = Flask(__name__)
# DEBUG level logs full request headers/bodies below — noisy, but this is a
# debugging-oriented proxy.
logging.basicConfig(level=logging.DEBUG)
# Upstream OpenAI-compatible endpoint this proxy forwards to.
DEEPINFRA_API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"
# Shared secret clients must present as a Bearer token; injected via the
# environment so it never lives in source. May be None if unset.
API_KEY = os.environ.get("API_KEY")
def authenticate():
    """Validate the incoming request's Bearer token against the server key.

    Returns:
        bool: True only when an ``Authorization: Bearer <token>`` header is
        present and ``<token>`` matches ``API_KEY``.
    """
    # A misconfigured server (no API_KEY set) must deny everyone, never
    # accidentally match; this also keeps compare_digest from seeing None.
    if not API_KEY:
        return False
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return False
    # Slice rather than split(" ")[1] so a token is taken verbatim even if it
    # contains a space.
    token = auth_header[len("Bearer "):]
    # Constant-time comparison avoids leaking key prefixes via timing.
    return hmac.compare_digest(token, API_KEY)
@app.route('/hf/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Proxy an OpenAI-style chat-completion request to the DeepInfra API.

    Authenticates the caller, forwards the (normalized) request upstream, and
    translates DeepInfra's reply back into OpenAI-compatible JSON — either as
    a single body or as a Server-Sent-Events stream when ``stream`` is true.

    Returns:
        A Flask response: 401 on bad auth, 400 on malformed JSON or an
        upstream-reported error, 500 on upstream/processing failure,
        otherwise the (possibly streamed) completion.
    """
    logging.debug(f"Received headers: {request.headers}")
    logging.debug(f"Received body: {request.get_data(as_text=True)}")
    if not authenticate():
        logging.warning("Unauthorized access attempt")
        return jsonify({"error": "Unauthorized"}), 401

    # NOTE: request.json raises werkzeug's BadRequest (not json.JSONDecodeError)
    # on malformed bodies, so the original except clause could never fire.
    # get_json(silent=True) returns None instead of raising, letting us emit
    # the intended 400 payload.
    openai_request = request.get_json(silent=True)
    if not isinstance(openai_request, dict):
        logging.error("Invalid JSON in request body")
        return jsonify({"error": "Invalid JSON in request body"}), 400
    logging.info(f"Received request: {openai_request}")

    deepinfra_request = {
        "model": openai_request.get("model", "meta-llama/Meta-Llama-3.1-405B-Instruct"),
        "temperature": openai_request.get("temperature", 0.7),
        "max_tokens": openai_request.get("max_tokens", 1000),
        "stream": openai_request.get("stream", False),
        "messages": openai_request.get("messages", [])
    }
    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if deepinfra_request["stream"] else "application/json"
    }
    try:
        # Explicit timeout: (connect, read). Without one, requests can block
        # forever on a wedged upstream. Long read timeout tolerates slow
        # generations between stream chunks.
        response = requests.post(
            DEEPINFRA_API_URL,
            json=deepinfra_request,
            headers=headers,
            stream=deepinfra_request["stream"],
            timeout=(10, 600),
        )
        response.raise_for_status()
        logging.debug(f"DeepInfra API response status: {response.status_code}")
        logging.debug(f"DeepInfra API response headers: {response.headers}")
    except requests.RequestException as e:
        logging.error(f"Error calling DeepInfra API: {str(e)}")
        return jsonify({"error": "Failed to call DeepInfra API"}), 500

    if deepinfra_request["stream"]:
        def generate():
            # Most recent successfully-parsed chunk; kept so usage stats (sent
            # by DeepInfra on the final chunk) can be re-emitted before [DONE].
            last_chunk = None
            for line in response.iter_lines():
                if not line:
                    # Blank lines are normal SSE event separators / keep-alives.
                    continue
                try:
                    line_text = line.decode('utf-8')
                    if not line_text.startswith('data: '):
                        continue
                    data_text = line_text.split('data: ', 1)[1]
                    if data_text == "[DONE]":
                        # Emit usage BEFORE the terminator: SSE clients stop
                        # reading at [DONE], so anything after it is lost.
                        # Guarding on last_chunk also fixes a NameError when no
                        # chunk ever parsed.
                        if last_chunk is not None and 'usage' in last_chunk:
                            final_chunk = {
                                "id": last_chunk['id'],
                                "object": "chat.completion.chunk",
                                "created": last_chunk['created'],
                                "model": last_chunk['model'],
                                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                                "usage": last_chunk['usage']
                            }
                            yield f"data: {json.dumps(final_chunk)}\n\n"
                        yield "data: [DONE]\n\n"
                        break
                    data = json.loads(data_text)
                    last_chunk = data
                    delta_content = data['choices'][0]['delta'].get('content', '')
                    openai_format = {
                        "id": data['id'],
                        "object": "chat.completion.chunk",
                        "created": data['created'],
                        "model": data['model'],
                        "choices": [
                            {
                                "index": 0,
                                "delta": {
                                    "content": delta_content
                                },
                                "finish_reason": data['choices'][0].get('finish_reason')
                            }
                        ]
                    }
                    yield f"data: {json.dumps(openai_format)}\n\n"
                except json.JSONDecodeError as e:
                    logging.error(f"JSON decode error: {e}. Raw line: {line}")
                    continue
                except Exception as e:
                    # Best-effort streaming: skip malformed chunks rather than
                    # killing the whole stream mid-response.
                    logging.error(f"Error processing line: {e}. Raw line: {line}")
                    continue
        return Response(stream_with_context(generate()), content_type='text/event-stream')
    else:
        try:
            deepinfra_response = response.json()
            logging.info(f"Received response from DeepInfra: {deepinfra_response}")
            if 'error' in deepinfra_response:
                return jsonify({"error": deepinfra_response['error']}), 400
            if 'choices' not in deepinfra_response or not deepinfra_response['choices']:
                return jsonify({"error": "Unexpected response format from DeepInfra"}), 500
            # Re-shape into the OpenAI chat.completion envelope; only the
            # first choice is forwarded (the proxy never requests n > 1).
            openai_response = {
                "id": deepinfra_response.get("id", ""),
                "object": "chat.completion",
                "created": deepinfra_response.get("created", 0),
                "model": deepinfra_response.get("model", ""),
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": deepinfra_response["choices"][0]["message"]["content"]
                        },
                        "finish_reason": deepinfra_response["choices"][0].get("finish_reason", "stop")
                    }
                ],
                "usage": deepinfra_response.get("usage", {})
            }
            return json.dumps(openai_response), 200, {'Content-Type': 'application/json'}
        except Exception as e:
            logging.error(f"Error processing DeepInfra response: {str(e)}")
            return jsonify({"error": "Failed to process DeepInfra response"}), 500
@app.route('/')
def home():
    """Landing page that points clients at the chat-completions endpoint."""
    return (
        "Welcome to the API proxy server. "
        "Please use the /hf/v1/chat/completions endpoint for chat completions."
    )
if __name__ == '__main__':
    # Bind to all interfaces on port 7860 (the conventional Hugging Face
    # Spaces port). Removed a stray trailing '|' that made the file invalid.
    app.run(host='0.0.0.0', port=7860)