antoncio commited on
Commit
c1fe6d3
·
1 Parent(s): a642e47

init commit

Browse files
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Small Python base image keeps the final container lean.
FROM python:3.9-slim

WORKDIR /app

# Build tools needed to compile native wheels (e.g. sentencepiece);
# clean the apt cache in the same layer so it doesn't bloat the image.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git g++ make wget && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Copy requirements first so the pip layer is cached independently of
# source-code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# 7860 is the port Hugging Face Spaces routes traffic to.
EXPOSE 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
  title: General Agent
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: General Agent
3
+ emoji: 🏃
4
+ colorFrom: red
5
+ colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
+ short_description: First attempt to build and expose an agent
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as t
2
+ import os
3
+ # from dotenv import load_dotenv
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
6
+ import logging
7
+
8
+ from src.utils import (
9
+ OpenAIClient,
10
+ TogetherAIClient,
11
+ GeminiClient,
12
+ GroqClient,
13
+ MistralClient,
14
+ )
15
+ from src.models_enums import ModelProvider
16
+
17
# load_dotenv()

# Fail fast at startup if the Together API key is missing or empty.
# The previous `assert os.environ['TOGETHER_API_KEY'] is not None` was
# doubly broken: the subscript raises KeyError before the assert ever
# evaluates, and asserts are stripped entirely under `python -O`.
if not os.environ.get("TOGETHER_API_KEY"):
    raise RuntimeError("TOGETHER_API_KEY environment variable must be set.")

# Configure basic logging to see messages in stdout (and thus in HF Space logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
24
+
25
class RequestData(BaseModel):
    """Request body for the /generate endpoint."""
    # The user prompt forwarded to the selected LLM.
    prompt: str
    # Cap on generated tokens; deliberately low default for quick tests.
    max_tokens: int = 50
    # Optional system prompt, prepended as the first chat message.
    system_prompt: t.Optional[str] = None
29
+
30
# Maps each provider identifier (the first /generate path segment) to the
# adapter class that talks to that provider's API.
MODEL_PROVIDER2CLIENT = {
    ModelProvider.OPENAI.value: OpenAIClient,
    ModelProvider.GEMINI.value: GeminiClient,
    ModelProvider.TOGETHERAI.value: TogetherAIClient,
    ModelProvider.GROQ.value: GroqClient,
    ModelProvider.MISTRAL.value: MistralClient,
}


app = FastAPI()
logger.info("FastAPI app initialized.")
41
+
42
+
43
+ # The application now starts without initializing a specific LLM,
44
+ # which makes it more flexible.
45
+
46
@app.post("/generate/{model_provider}/{model_name:path}")
async def generate_text(
    model_provider: str,
    model_name: str,
    request: RequestData
):
    """
    Generates text using a specified LLM provider and model.

    Example:
        POST /generate/togetherai/meta-llama/Llama-3.3-70B-Instruct-Turbo-Free
        with body: {"prompt": "...", "max_tokens": 100}
    """
    logger.info(f"Received POST request to /generate/{model_provider}/{model_name}.")

    # Reject unknown providers with a 400 before doing any work.
    llm_client_class = MODEL_PROVIDER2CLIENT.get(model_provider)
    if llm_client_class is None:
        logger.error(f"Invalid model provider: {model_provider}")
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model provider: {model_provider}. "
                   f"Available providers: {[p.value for p in ModelProvider]}"
        )

    try:
        # Instantiate the adapter for this provider/model pair and await
        # its common async call interface.
        llm_client = llm_client_class(model=model_name)
        return await llm_client(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens
        )
    except Exception as e:
        logger.error(
            f"Error during text generation for {model_provider}/{model_name}: {str(e)}",
            exc_info=True
        )
        raise HTTPException(status_code=500, detail=str(e))
90
+
91
@app.get("/health")
async def health_check():
    """Liveness probe; always reports the service as up."""
    logger.info("Received GET request to /health.")
    status_payload = {"status": "ok"}
    return status_payload
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
openai
fastapi==0.109.0
uvicorn==0.27.0
# PyTorch CPU wheels live on a separate index. Per-requirement
# "--index-url" is not valid requirements-file syntax; index options
# must appear on their own line.
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.2.1
transformers==4.40.2
accelerate==0.29.3
sentencepiece==0.2.0
numpy==1.26.4
protobuf==3.20.3
python-dotenv
together
# Required by src/utils.py for the Gemini/Groq/Mistral REST adapters.
aiohttp
src/__init__.py ADDED
File without changes
src/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.91 kB). View file
 
src/models_enums.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
# Supported LLM backend providers; each value is the identifier used in
# the /generate/{model_provider} URL segment.
ModelProvider = Enum(
    'ModelProvider',
    [
        ('OPENAI', 'openai'),
        ('GEMINI', 'gemini'),
        ('MISTRAL', 'mistral'),
        ('TOGETHERAI', 'togetherai'),
        ('GROQ', 'groq'),
    ],
)
9
+
src/utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as t
2
+ import asyncio
3
+ from abc import ABC, abstractmethod
4
+ import os
5
+
6
+ # External Libraries
7
+ import requests
8
+ import json
9
+ from together import Together
10
+ from openai import AsyncOpenAI
11
+ import aiohttp # Using aiohttp for async HTTP requests
12
+
13
+ # A standard response type to avoid type errors with Together.
14
+ try:
15
+ from together.types.chat_completions import ChatCompletionResponse
16
+ except ImportError:
17
+ ChatCompletionResponse = t.Any
18
+
19
+
20
# --- ABSTRACT BASE CLASS (The Core Abstraction) ---
class BaseLLMClient(ABC):
    """
    Common asynchronous interface shared by every concrete LLM adapter.

    Subclasses hold the target model name plus any provider-specific
    keyword arguments, and must implement the async ``__call__``.
    """

    def __init__(self, model: str, **kwargs):
        # Concrete subclasses replace this with their provider SDK client.
        self.client = None
        self.model = model
        # Extra provider options forwarded on every request.
        self.kwargs = kwargs

    @abstractmethod
    async def __call__(
        self,
        prompt: str,
        max_tokens: int = 1_000,
        system_prompt: t.Optional[str] = None,
        **kwargs
    ) -> str:
        """Send *prompt* to the provider and return the generated text."""
        ...

    def _create_messages(
        self,
        prompt: str,
        system_prompt: t.Optional[str] = None
    ) -> t.List[t.Dict[str, str]]:
        """
        Build an OpenAI-style chat message list.

        The system prompt, when provided, always comes first so it sets
        the model's context before the user turn.
        """
        prefix = [{"role": "system", "content": system_prompt}] if system_prompt else []
        return prefix + [{"role": "user", "content": prompt}]
58
+
59
+
60
+ # --- CONCRETE IMPLEMENTATIONS (The Adapters) ---
61
+
62
class OpenAIClient(BaseLLMClient):
    """
    Adapter for the OpenAI (and OpenAI-compatible) Async API client.
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model, **kwargs)
        # Key is read from the environment; a missing key surfaces as an
        # error string from __call__, not a crash here.
        self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    async def __call__(
        self,
        prompt: str,
        max_tokens: int = 1_000,
        system_prompt: t.Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate a completion; on any failure return the error as text."""
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=self._create_messages(prompt, system_prompt),
                max_tokens=max_tokens,
                **self.kwargs,
                **kwargs,
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error from OpenAI: {e}"
89
+
90
+
91
class TogetherAIClient(BaseLLMClient):
    """
    Adapter for the Together API client.

    The Together SDK is synchronous, so each call is offloaded to a
    worker thread via ``asyncio.to_thread`` to keep the event loop free.
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model, **kwargs)
        # Note: Together() automatically looks for TOGETHER_API_KEY env var
        self.client = Together()

    async def __call__(
        self,
        prompt: str,
        max_tokens: int = 1_000,
        system_prompt: t.Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate a completion; on any failure return the error as text."""
        try:
            response: ChatCompletionResponse = await asyncio.to_thread(
                self.client.chat.completions.create,
                model=self.model,
                messages=self._create_messages(prompt, system_prompt),
                max_tokens=max_tokens,
                **self.kwargs,
                **kwargs,
            )
            return str(response.choices[0].message.content)
        except Exception as e:
            return f"Error from TogetherAI: {e}"
123
+
124
+
125
class GeminiClient(BaseLLMClient):
    """
    Adapter for the Gemini REST API, using aiohttp for async HTTP requests.

    Fix: the previous version sent OpenAI-style ``{"role", "content"}``
    messages in ``contents``; Gemini's generateContent endpoint rejects
    that shape. It expects entries shaped ``{"role", "parts": [{"text": ...}]}``
    and takes the system prompt as a top-level ``systemInstruction``.
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model, **kwargs)
        self.api_key = os.getenv("GEMINI_API_KEY")
        self.url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model}:generateContent?key={self.api_key}"
        )

    async def __call__(
        self,
        prompt: str,
        max_tokens: int = 1_000,
        system_prompt: t.Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate text via Gemini; returns error text instead of raising."""
        if not self.api_key:
            return "Error: GEMINI_API_KEY not found."

        payload = {
            # Gemini's native message format, not the OpenAI chat format.
            "contents": [{"role": "user", "parts": [{"text": prompt}]}],
            "generationConfig": {"maxOutputTokens": max_tokens},
            **self.kwargs,
            **kwargs,
        }
        if system_prompt:
            # System prompts are a top-level field in the Gemini API.
            payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}

        headers = {"Content-Type": "application/json"}

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.url, headers=headers, json=payload) as response:
                    response.raise_for_status()
                    response_data = await response.json()
                    return response_data['candidates'][0]['content']['parts'][0]['text']
        except aiohttp.ClientError as e:
            return f"Error from Gemini (requests): {e}"
        except (KeyError, IndexError) as e:
            return f"Error parsing Gemini response: {e}"
164
+
165
+
166
class GroqClient(BaseLLMClient):
    """
    Adapter for the Groq REST API, using aiohttp for async HTTP requests.
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model, **kwargs)
        self.api_key = os.getenv("GROQ_API_KEY")
        # Groq exposes an OpenAI-compatible chat completions endpoint.
        self.url = "https://api.groq.com/openai/v1/chat/completions"

    async def __call__(
        self,
        prompt: str,
        max_tokens: int = 1_000,
        system_prompt: t.Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate a completion; on any failure return the error as text."""
        if not self.api_key:
            return "Error: GROQ_API_KEY not found."

        payload = {
            "model": self.model,
            "messages": self._create_messages(prompt, system_prompt),
            "max_tokens": max_tokens,
            **self.kwargs,
            **kwargs
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.url, headers=headers, json=payload) as response:
                    response.raise_for_status()
                    body = await response.json()
                    return body['choices'][0]['message']['content']
        except aiohttp.ClientError as e:
            return f"Error from Groq (requests): {e}"
        except (KeyError, IndexError) as e:
            return f"Error parsing Groq response: {e}"
209
+
210
+
211
class MistralClient(BaseLLMClient):
    """
    Adapter for the Mistral REST API, using aiohttp for async HTTP requests.
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model, **kwargs)
        self.api_key = os.getenv("MISTRAL_API_KEY")
        self.url = "https://api.mistral.ai/v1/chat/completions"

    async def __call__(
        self,
        prompt: str,
        max_tokens: int = 1_000,
        system_prompt: t.Optional[str] = None,
        **kwargs
    ) -> str:
        """Generate a completion; on any failure return the error as text."""
        if not self.api_key:
            return "Error: MISTRAL_API_KEY not found."

        payload = {
            "model": self.model,
            "messages": self._create_messages(prompt, system_prompt),
            "max_tokens": max_tokens,
            **self.kwargs,
            **kwargs
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.url, headers=headers, json=payload) as response:
                    response.raise_for_status()
                    body = await response.json()
                    return body['choices'][0]['message']['content']
        except aiohttp.ClientError as e:
            return f"Error from Mistral (requests): {e}"
        except (KeyError, IndexError) as e:
            return f"Error parsing Mistral response: {e}"
254
+
255
+
256
+ # ('openai', OpenAIClient(model="gpt-3.5-turbo")),
257
+ # ('togetherai', TogetherAIClient(model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free")),
258
+ # ('gemini', GeminiClient(model="gemini-1.5-flash-latest")),
259
+ # ('groq', GroqClient(model="llama3-8b-8192")),
260
+ # ('mistral', MistralClient(model="mistral-tiny")),