muralipala1504 commited on
Commit
a26a2bd
·
1 Parent(s): 6c2fc64

feat: add Cerebras provider as Groq failover

Browse files
deepshell-backend/deepshell/_cerebras.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+ from openai import OpenAI
4
+
5
+ CEREBRAS_HOST = "https://api.cerebras.ai/v1"
6
+ CEREBRAS_MODEL = os.getenv("CEREBRAS_MODEL", "llama3.1-8b")
7
+
8
+ class CerebrasWrapper:
9
+ def __init__(self):
10
+ api_key = os.getenv("CEREBRAS_API_KEY")
11
+ if not api_key:
12
+ raise RuntimeError("CEREBRAS_API_KEY is not set")
13
+ self.client = OpenAI(
14
+ base_url=CEREBRAS_HOST,
15
+ api_key=api_key
16
+ )
17
+
18
+ def chat(
19
+ self,
20
+ prompt: str,
21
+ max_tokens: int = 1024,
22
+ temperature: float = 0.2,
23
+ stream: bool = False,
24
+ system_prompt: Optional[str] = None
25
+ ):
26
+ from .llm import DEEPSHELL_SYSTEM_PROMPT
27
+ sys_prompt = system_prompt if system_prompt is not None else DEEPSHELL_SYSTEM_PROMPT
28
+ messages = []
29
+ if sys_prompt:
30
+ messages.append({"role": "system", "content": sys_prompt})
31
+ messages.append({"role": "user", "content": prompt})
32
+
33
+ response = self.client.chat.completions.create(
34
+ model=CEREBRAS_MODEL,
35
+ messages=messages,
36
+ max_tokens=max_tokens,
37
+ temperature=temperature,
38
+ stream=stream
39
+ )
40
+ return response
41
+
42
+ def get_cerebras_client():
43
+ return CerebrasWrapper()
deepshell-backend/deepshell/llm.py CHANGED
@@ -99,6 +99,11 @@ def get_global_client(provider: str = "groq") -> Any:
99
  if _singleton["client"] and _singleton["provider"] == provider:
100
  return _singleton["client"]
101
 
 
 
 
 
 
102
  if provider == "ollama":
103
  from deepshell._ollama import get_ollama_client
104
  _singleton["client"] = get_ollama_client()
 
99
  if _singleton["client"] and _singleton["provider"] == provider:
100
  return _singleton["client"]
101
 
102
+ if provider == "cerebras":
103
+ from deepshell._cerebras import get_cerebras_client
104
+ _singleton["client"] = get_cerebras_client()
105
+ _singleton["provider"] = provider
106
+ return _singleton["client"]
107
  if provider == "ollama":
108
  from deepshell._ollama import get_ollama_client
109
  _singleton["client"] = get_ollama_client()
docker-compose.yml CHANGED
@@ -9,4 +9,6 @@ services:
9
  - PROVIDER=${PROVIDER:-groq}
10
  - OLLAMA_HOST=${OLLAMA_HOST:-http://172.17.0.1:11434}
11
  - OLLAMA_MODEL=${OLLAMA_MODEL:-phi3:latest}
 
 
12
  restart: unless-stopped
 
9
  - PROVIDER=${PROVIDER:-groq}
10
  - OLLAMA_HOST=${OLLAMA_HOST:-http://172.17.0.1:11434}
11
  - OLLAMA_MODEL=${OLLAMA_MODEL:-phi3:latest}
12
+ - CEREBRAS_API_KEY=${CEREBRAS_API_KEY}
13
+ - CEREBRAS_MODEL=${CEREBRAS_MODEL:-llama3.1-8b}
14
  restart: unless-stopped