ethiotech4848 committed on
Commit
6466f74
·
verified ·
1 Parent(s): 66c69bb

Create deepinfra_handler.py

Browse files
Files changed (1) hide show
  1. deepinfra_handler.py +65 -0
deepinfra_handler.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import requests
3
+ from typing import Dict, Any, Generator, Optional
4
+
5
class DeepInfraHandler:
    """Small client for the chat-completions endpoint at ``API_URL``.

    Builds an OpenAI-style chat payload from keyword arguments, POSTs it
    with ``requests`` and returns either the decoded JSON body
    (non-streaming) or a generator of text fragments (streaming).
    """

    API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"

    def __init__(self):
        # Headers sent with every request. No Authorization header is set.
        # NOTE(review): "br" and "zstd" are advertised below; requests/urllib3
        # presumably only decode those when the optional brotli / zstandard
        # packages are installed -- confirm for the target deployment.
        self.headers = {
            "Accept": "text/event-stream",
            "Accept-Encoding": "gzip, deflate, br, zstd",
            "Content-Type": "application/json",
            "Connection": "keep-alive",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36",
        }

    def _prepare_payload(self, **kwargs) -> Dict[str, Any]:
        """Build the JSON body for the API request.

        Recognised keys: ``model``, ``messages``, ``temperature``,
        ``max_tokens``, ``top_p``, ``frequency_penalty``,
        ``presence_penalty``, ``stop`` and ``stream``. Any other keyword is
        ignored. ``model`` and ``messages`` default to ``None`` when absent;
        the remaining keys fall back to the defaults written below.
        """
        return {
            "model": kwargs.get("model"),
            "messages": kwargs.get("messages"),
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "frequency_penalty": kwargs.get("frequency_penalty", 0.0),
            "presence_penalty": kwargs.get("presence_penalty", 0.0),
            "stop": kwargs.get("stop", []),
            "stream": kwargs.get("stream", False),
        }

    def generate_completion(self, **kwargs) -> Any:
        """Send a chat-completion request and return its result.

        Accepts the keywords understood by ``_prepare_payload`` plus an
        optional ``timeout`` (seconds, or a ``(connect, read)`` tuple) that
        is handed to ``requests.post``. It defaults to ``None``, i.e. no
        timeout, which is what happens when the argument is omitted.

        Returns:
            A generator of text fragments when ``stream`` is truthy,
            otherwise the decoded JSON response body.

        Raises:
            Exception: if a non-streaming response body is not valid JSON.
            requests.RequestException: propagated from ``requests.post``.
        """
        payload = self._prepare_payload(**kwargs)

        response = requests.post(
            self.API_URL,
            headers=self.headers,
            json=payload,
            stream=payload["stream"],
            timeout=kwargs.get("timeout"),
        )

        if payload["stream"]:
            return self._handle_streaming_response(response)
        return self._handle_regular_response(response)

    @staticmethod
    def _extract_delta_content(event: Any) -> Any:
        """Return ``choices[0].delta.content`` from a decoded event, else None.

        Every level is type-checked so that an unexpectedly shaped event is
        skipped instead of raising.
        """
        if not isinstance(event, dict):
            return None
        choices = event.get("choices")
        if not isinstance(choices, list) or not choices:
            return None
        first = choices[0]
        if not isinstance(first, dict):
            return None
        delta = first.get("delta")
        if not isinstance(delta, dict):
            return None
        return delta.get("content")

    def _handle_streaming_response(self, response) -> Generator[str, None, None]:
        """Yield the text fragments carried by a server-sent-event stream.

        Only lines starting with ``data:`` are considered. The ``[DONE]``
        sentinel ends the stream; lines whose payload is not valid JSON or
        that carry no non-empty delta content are skipped. The response is
        closed when the generator finishes or is closed by the consumer.
        """
        try:
            for raw_line in response.iter_lines(decode_unicode=True):
                # iter_lines can still hand back bytes when the response
                # declares no encoding, so normalise to str first.
                if isinstance(raw_line, bytes):
                    raw_line = raw_line.decode("utf-8", errors="replace")
                if not raw_line.startswith("data:"):
                    continue

                data = raw_line[5:].strip()
                # "[DONE]" is not JSON, so it must be tested before parsing.
                if data == "[DONE]":
                    break

                # Keep the try body to the parse alone: the yield below must
                # not sit under an except clause, otherwise the GeneratorExit
                # raised by close() would be swallowed.
                try:
                    event = json.loads(data)
                except ValueError:
                    continue

                delta_content = self._extract_delta_content(event)
                if delta_content:
                    yield delta_content
        finally:
            # Release the connection even if the consumer stops early.
            response.close()

    def _handle_regular_response(self, response) -> Dict[str, Any]:
        """Decode a non-streaming response body as JSON.

        Raises:
            Exception: if ``response.json()`` fails; the original error is
                attached as ``__cause__``.
        """
        try:
            return response.json()
        except Exception as e:
            raise Exception(f"Error processing response: {str(e)}") from e