alx-d commited on
Commit
8edd557
·
verified ·
1 Parent(s): 0ded8c7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. advanced_rag.py +16 -6
advanced_rag.py CHANGED
@@ -25,6 +25,8 @@ import gradio as gr
25
  import requests
26
  from pydantic import PrivateAttr
27
 
 
 
28
  # Add Mistral imports with fallback handling
29
  try:
30
  from mistralai import Mistral
@@ -142,14 +144,15 @@ class ElevatedRagChain:
142
  hf_api_token = os.environ.get("HF_API_TOKEN")
143
  if not hf_api_token:
144
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
145
- client = InferenceClient(token=hf_api_token, timeout=180)
146
  def remote_generate(prompt: str) -> str:
147
  response = client.text_generation(
148
  prompt,
149
  model=repo_id,
150
  temperature=self.temperature,
151
  top_p=self.top_p,
152
- repetition_penalty=1.1
 
153
  )
154
  return response
155
  from langchain.llms.base import LLM
@@ -172,20 +175,25 @@ class ElevatedRagChain:
172
  if not MISTRAL_AVAILABLE:
173
  raise ImportError("Mistral client library not installed. Install with: pip install mistralai")
174
  from langchain.llms.base import LLM
 
175
  class MistralLLM(LLM):
176
  temperature: float = 0.7
177
  top_p: float = 0.95
178
- _client: Any = PrivateAttr() # Declare _client as a private attribute
179
- def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95):
180
- super().__init__()
 
 
181
  self._client = Mistral(api_key=api_key)
182
  self.temperature = temperature
183
  self.top_p = top_p
 
184
  @property
185
  def _llm_type(self) -> str:
186
  return "mistral_llm"
 
187
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
188
- response = self._client.chat.complete(
189
  model="mistral-small-latest",
190
  messages=[{"role": "user", "content": prompt}],
191
  temperature=self.temperature,
@@ -193,9 +201,11 @@ class ElevatedRagChain:
193
  max_tokens=32000
194
  )
195
  return response.choices[0].message.content
 
196
  @property
197
  def _identifying_params(self) -> dict:
198
  return {"model": "mistral-small-latest"}
 
199
  mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
200
  debug_print("Mistral API pipeline created successfully.")
201
  return mistral_llm
 
25
  import requests
26
  from pydantic import PrivateAttr
27
 
28
+ print("Pydantic Version: ")
29
+ print(pydantic.__version__)
30
  # Add Mistral imports with fallback handling
31
  try:
32
  from mistralai import Mistral
 
144
  hf_api_token = os.environ.get("HF_API_TOKEN")
145
  if not hf_api_token:
146
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
147
+ client = InferenceClient(token=hf_api_token, timeout=240)
148
  def remote_generate(prompt: str) -> str:
149
  response = client.text_generation(
150
  prompt,
151
  model=repo_id,
152
  temperature=self.temperature,
153
  top_p=self.top_p,
154
+ repetition_penalty=1.1,
155
+ wait_for_model=True,
156
  )
157
  return response
158
  from langchain.llms.base import LLM
 
175
  if not MISTRAL_AVAILABLE:
176
  raise ImportError("Mistral client library not installed. Install with: pip install mistralai")
177
  from langchain.llms.base import LLM
178
+ from typing import Any, Optional, List
179
  class MistralLLM(LLM):
180
  temperature: float = 0.7
181
  top_p: float = 0.95
182
+ _client: Any = PrivateAttr(default=None) # Set default to None
183
+
184
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
185
+ # Do not pass api_key to super().__init__ since it's not a field
186
+ super().__init__(**kwargs)
187
  self._client = Mistral(api_key=api_key)
188
  self.temperature = temperature
189
  self.top_p = top_p
190
+
191
  @property
192
  def _llm_type(self) -> str:
193
  return "mistral_llm"
194
+
195
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
196
+ response = self._client.chat.complete(
197
  model="mistral-small-latest",
198
  messages=[{"role": "user", "content": prompt}],
199
  temperature=self.temperature,
 
201
  max_tokens=32000
202
  )
203
  return response.choices[0].message.content
204
+
205
  @property
206
  def _identifying_params(self) -> dict:
207
  return {"model": "mistral-small-latest"}
208
+
209
  mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
210
  debug_print("Mistral API pipeline created successfully.")
211
  return mistral_llm