akryldigital committed on
Commit 347831f · verified · 1 Parent(s): 0714a89

remove extra adapters

Files changed (1)
  1. src/llm/adapters.py +55 -79
src/llm/adapters.py CHANGED
@@ -1,48 +1,24 @@
 """LLM client adapters for different providers."""
-
-from typing import Dict, Any, List, Optional, Union
-from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Optional, Union
 
-# LangChain imports
-from langchain_mistralai.chat_models import ChatMistralAI
+# from langchain_ollama import ChatOllama
 from langchain_openai.chat_models import ChatOpenAI
-from langchain_ollama import ChatOllama
+# from langchain_mistralai.chat_models import ChatMistralAI
+
 
-# Legacy client dependencies
+# Legacy dependencies
 from huggingface_hub import InferenceClient
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_community.llms import HuggingFaceEndpoint
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_community.chat_models.huggingface import ChatHuggingFace
 
-# Configuration loader
-from ..config.loader import load_config
 
-# Load configuration once at module level
+from ..config.loader import load_config
 _config = load_config()
 
-
-# Legacy client factory functions (inlined from auditqa_old.reader)
-def _create_inf_provider_client():
-    """Create INF_PROVIDERS client."""
-    reader_config = _config.get("reader", {})
-    inf_config = reader_config.get("INF_PROVIDERS", {})
-
-    api_key = inf_config.get("api_key")
-    if not api_key:
-        raise ValueError("INF_PROVIDERS api_key not found in configuration")
-
-    provider = inf_config.get("provider")
-    if not provider:
-        raise ValueError("INF_PROVIDERS provider not found in configuration")
-
-    return InferenceClient(
-        provider=provider,
-        api_key=api_key,
-        bill_to="GIZ",
-    )
-
-
+# Legacy functions
 def _create_nvidia_client():
     """Create NVIDIA client."""
     reader_config = _config.get("reader", {})
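
Note: the deleted _create_inf_provider_client wrapped huggingface_hub.InferenceClient, which this module still uses via _create_nvidia_client. For context, a minimal sketch of how such a client is driven; the provider name, token, model id, and prompt below are placeholders, not values from this repo:

from huggingface_hub import InferenceClient

# Placeholder credentials; the removed factory read these from the reader config.
client = InferenceClient(
    provider="novita",    # hypothetical inference provider name
    api_key="hf_xxx",     # placeholder token
    bill_to="GIZ",        # org billing, as in the removed factory
)

# chat_completion is InferenceClient's standard chat entry point.
response = client.chat_completion(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=128,
)
print(response.choices[0].message.content)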
@@ -138,31 +114,31 @@ class BaseLLMAdapter(ABC):
         pass
 
 
-class MistralAdapter(BaseLLMAdapter):
-    """Adapter for Mistral AI models."""
+# class MistralAdapter(BaseLLMAdapter):
+#     """Adapter for Mistral AI models."""
 
-    def __init__(self, config: Dict[str, Any]):
-        super().__init__(config)
-        self.model = ChatMistralAI(
-            model=config.get("model", "mistral-medium-latest")
-        )
+#     def __init__(self, config: Dict[str, Any]):
+#         super().__init__(config)
+#         self.model = ChatMistralAI(
+#             model=config.get("model", "mistral-medium-latest")
+#         )
 
-    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
-        """Generate response using Mistral."""
-        response = self.model.invoke(messages)
+#     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+#         """Generate response using Mistral."""
+#         response = self.model.invoke(messages)
 
-        return LLMResponse(
-            content=response.content,
-            model=self.config.get("model", "mistral-medium-latest"),
-            provider="mistral",
-            metadata={"usage": getattr(response, 'usage_metadata', {})}
-        )
+#         return LLMResponse(
+#             content=response.content,
+#             model=self.config.get("model", "mistral-medium-latest"),
+#             provider="mistral",
+#             metadata={"usage": getattr(response, 'usage_metadata', {})}
+#         )
 
-    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
-        """Generate streaming response using Mistral."""
-        for chunk in self.model.stream(messages):
-            if chunk.content:
-                yield chunk.content
+#     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+#         """Generate streaming response using Mistral."""
+#         for chunk in self.model.stream(messages):
+#             if chunk.content:
+#                 yield chunk.content
 
 
 class OpenAIAdapter(BaseLLMAdapter):
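
The commented-out adapter still shows the shape of the shared interface: every adapter takes a config dict, generate returns an LLMResponse, and stream_generate yields text chunks. BaseLLMAdapter and LLMResponse are defined above this hunk and are not shown in the diff; the sketch below is a reconstruction inferred from how the adapters use them, not the file's actual definitions:

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class LLMResponse:
    # Field names taken from the LLMResponse(...) calls visible in the diff.
    content: str
    model: str
    provider: str
    metadata: Dict[str, Any] = field(default_factory=dict)

class BaseLLMAdapter(ABC):
    def __init__(self, config: Dict[str, Any]):
        self.config = config  # adapters read model defaults from here

    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Return a complete LLMResponse for the given chat messages."""

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Yield response text chunks as they arrive."""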
@@ -192,34 +168,34 @@ class OpenAIAdapter(BaseLLMAdapter):
                 yield chunk.content
 
 
-class OllamaAdapter(BaseLLMAdapter):
-    """Adapter for Ollama models."""
+# class OllamaAdapter(BaseLLMAdapter):
+#     """Adapter for Ollama models."""
 
-    def __init__(self, config: Dict[str, Any]):
-        super().__init__(config)
-        self.model = ChatOllama(
-            model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
-            base_url=config.get("base_url", "http://localhost:11434/"),
-            temperature=config.get("temperature", 0.8),
-            num_predict=config.get("num_predict", 256)
-        )
+#     def __init__(self, config: Dict[str, Any]):
+#         super().__init__(config)
+#         self.model = ChatOllama(
+#             model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
+#             base_url=config.get("base_url", "http://localhost:11434/"),
+#             temperature=config.get("temperature", 0.8),
+#             num_predict=config.get("num_predict", 256)
+#         )
 
-    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
-        """Generate response using Ollama."""
-        response = self.model.invoke(messages)
+#     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+#         """Generate response using Ollama."""
+#         response = self.model.invoke(messages)
 
-        return LLMResponse(
-            content=response.content,
-            model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
-            provider="ollama",
-            metadata={}
-        )
-
-    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
-        """Generate streaming response using Ollama."""
-        for chunk in self.model.stream(messages):
-            if chunk.content:
-                yield chunk.content
+#         return LLMResponse(
+#             content=response.content,
+#             model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
+#             provider="ollama",
+#             metadata={}
+#         )
+
+#     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+#         """Generate streaming response using Ollama."""
+#         for chunk in self.model.stream(messages):
+#             if chunk.content:
+#                 yield chunk.content
 
 
 class OpenRouterAdapter(BaseLLMAdapter):
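
Since OllamaAdapter is only commented out rather than deleted, re-enabling it would also mean restoring the langchain_ollama import above. A standalone sketch of the same calls, assuming a local Ollama server on the default port and the langchain-ollama package installed; the parameters mirror the commented-out adapter's defaults:

from langchain_ollama import ChatOllama

# Same defaults as the commented-out adapter; any locally pulled model tag works.
model = ChatOllama(
    model="mistral-small3.1:24b-instruct-2503-q8_0",
    base_url="http://localhost:11434/",
    temperature=0.8,
    num_predict=256,  # cap on generated tokens
)

messages = [{"role": "user", "content": "Say hello"}]
print(model.invoke(messages).content)  # blocking generate, as in generate()

for chunk in model.stream(messages):   # streaming, as in stream_generate()
    print(chunk.content, end="", flush=True)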
 