File size: 6,450 Bytes
8d1ecb1
 
 
 
531dd91
8d1ecb1
 
31853b5
8d1ecb1
 
 
 
 
 
 
 
 
 
 
8e95cf7
8d1ecb1
31853b5
 
 
8d1ecb1
31853b5
8d1ecb1
 
31853b5
 
 
 
 
 
531dd91
8d1ecb1
 
531dd91
8d1ecb1
 
 
 
 
 
 
 
 
 
 
531dd91
 
 
31853b5
 
531dd91
 
 
8d1ecb1
 
 
531dd91
8d1ecb1
31853b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531dd91
 
 
 
 
 
 
 
 
 
 
 
 
8d1ecb1
 
531dd91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d1ecb1
31853b5
 
 
 
8d1ecb1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import requests
from flask import Flask, request, Response, stream_with_context, jsonify
import json
import logging

# Flask app; DEBUG logging is enabled globally — request headers and full
# bodies are logged below, so avoid this level in production with secrets.
app = Flask(__name__)
logging.basicConfig(level=logging.DEBUG)

# Upstream OpenAI-compatible chat-completions endpoint being proxied.
DEEPINFRA_API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"
# Bearer token clients must present. If the env var is unset this is None,
# so no token ever matches and every request is rejected with 401.
API_KEY = os.environ.get("API_KEY")

def authenticate():
    """Validate the caller's Bearer token against the configured API_KEY.

    Reads the Authorization header of the current Flask request and
    returns True only for a header of the form 'Bearer <token>' whose
    token equals API_KEY; returns False otherwise.
    """
    header = request.headers.get("Authorization")
    if header is None:
        return False
    if not header.startswith("Bearer "):
        return False
    # Second space-separated field is the token (same split the caller sent).
    return header.split(" ")[1] == API_KEY

@app.route('/hf/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Proxy an OpenAI-style chat-completion request to DeepInfra.

    Authenticates the caller via Bearer token, translates the JSON body,
    forwards it to DEEPINFRA_API_URL, and returns either a streaming SSE
    response or a single JSON completion, both in OpenAI wire format.

    Returns:
        A Flask response: 401 on bad auth, 400 on malformed JSON or an
        upstream-reported error, 500 on upstream/parse failures,
        otherwise the translated completion (streamed or JSON).
    """
    logging.debug(f"Received headers: {request.headers}")
    logging.debug(f"Received body: {request.get_data(as_text=True)}")

    if not authenticate():
        logging.warning("Unauthorized access attempt")
        return jsonify({"error": "Unauthorized"}), 401

    # BUGFIX: `request.json` raises werkzeug's BadRequest on malformed
    # bodies, not json.JSONDecodeError, so the original try/except never
    # caught anything. Parse silently and check for None instead.
    openai_request = request.get_json(silent=True)
    if openai_request is None:
        logging.error("Invalid JSON in request body")
        return jsonify({"error": "Invalid JSON in request body"}), 400

    logging.info(f"Received request: {openai_request}")

    # Map the OpenAI-style request onto DeepInfra's (identical) schema,
    # supplying defaults for any missing fields.
    deepinfra_request = {
        "model": openai_request.get("model", "meta-llama/Meta-Llama-3.1-405B-Instruct"),
        "temperature": openai_request.get("temperature", 0.7),
        "max_tokens": openai_request.get("max_tokens", 1000),
        "stream": openai_request.get("stream", False),
        "messages": openai_request.get("messages", [])
    }

    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if deepinfra_request["stream"] else "application/json"
    }

    try:
        response = requests.post(
            DEEPINFRA_API_URL,
            json=deepinfra_request,
            headers=headers,
            stream=deepinfra_request["stream"],
        )
        response.raise_for_status()
        logging.debug(f"DeepInfra API response status: {response.status_code}")
        logging.debug(f"DeepInfra API response headers: {response.headers}")
    except requests.RequestException as e:
        logging.error(f"Error calling DeepInfra API: {str(e)}")
        return jsonify({"error": "Failed to call DeepInfra API"}), 500

    if deepinfra_request["stream"]:
        return _stream_response(response)
    return _plain_response(response)


def _stream_response(response):
    """Re-emit DeepInfra's SSE stream as OpenAI-format chunk events.

    response: a `requests` Response opened with stream=True.
    """
    def generate():
        # Most recent successfully parsed chunk. BUGFIX: initialised to
        # an empty dict so the post-loop usage check cannot raise
        # NameError when the stream produced no parseable data.
        data = {}
        for line in response.iter_lines():
            if not line:
                logging.warning("Received empty line from DeepInfra API")
                continue
            try:
                line_text = line.decode('utf-8')
                if not line_text.startswith('data: '):
                    continue
                data_text = line_text.split('data: ', 1)[1]
                if data_text == "[DONE]":
                    yield "data: [DONE]\n\n"
                    break
                data = json.loads(data_text)

                delta_content = data['choices'][0]['delta'].get('content', '')
                openai_format = {
                    "id": data['id'],
                    "object": "chat.completion.chunk",
                    "created": data['created'],
                    "model": data['model'],
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": delta_content},
                            "finish_reason": data['choices'][0].get('finish_reason')
                        }
                    ]
                }
                yield f"data: {json.dumps(openai_format)}\n\n"
            except json.JSONDecodeError as e:
                logging.error(f"JSON decode error: {e}. Raw line: {line}")
            except Exception as e:
                logging.error(f"Error processing line: {e}. Raw line: {line}")

        # Forward the final usage block, if the last chunk carried one.
        if 'usage' in data:
            final_chunk = {
                "id": data['id'],
                "object": "chat.completion.chunk",
                "created": data['created'],
                "model": data['model'],
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                "usage": data['usage']
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"

    return Response(stream_with_context(generate()), content_type='text/event-stream')


def _plain_response(response):
    """Translate a non-streaming DeepInfra JSON response to OpenAI format."""
    try:
        deepinfra_response = response.json()
        logging.info(f"Received response from DeepInfra: {deepinfra_response}")

        if 'error' in deepinfra_response:
            return jsonify({"error": deepinfra_response['error']}), 400

        if 'choices' not in deepinfra_response or not deepinfra_response['choices']:
            return jsonify({"error": "Unexpected response format from DeepInfra"}), 500

        openai_response = {
            "id": deepinfra_response.get("id", ""),
            "object": "chat.completion",
            "created": deepinfra_response.get("created", 0),
            "model": deepinfra_response.get("model", ""),
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": deepinfra_response["choices"][0]["message"]["content"]
                    },
                    "finish_reason": deepinfra_response["choices"][0].get("finish_reason", "stop")
                }
            ],
            "usage": deepinfra_response.get("usage", {})
        }
        return json.dumps(openai_response), 200, {'Content-Type': 'application/json'}
    except Exception as e:
        logging.error(f"Error processing DeepInfra response: {str(e)}")
        return jsonify({"error": "Failed to process DeepInfra response"}), 500

@app.route('/')
def home():
    """Landing page that points callers at the chat-completions endpoint."""
    message = (
        "Welcome to the API proxy server. "
        "Please use the /hf/v1/chat/completions endpoint for chat completions."
    )
    return message

# Run the Flask development server when executed directly, listening on
# all interfaces. NOTE(review): port 7860 and the /hf/ route prefix
# suggest deployment on Hugging Face Spaces — confirm with the deployer.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)