Upload folder using huggingface_hub

- advanced_rag.py: +16 -6

advanced_rag.py CHANGED
@@ -25,6 +25,8 @@ import gradio as gr
 import requests
 from pydantic import PrivateAttr
 
+print("Pydantic Version: ")
+print(pydantic.__version__)
 # Add Mistral imports with fallback handling
 try:
     from mistralai import Mistral
@@ -142,14 +144,15 @@ class ElevatedRagChain:
         hf_api_token = os.environ.get("HF_API_TOKEN")
         if not hf_api_token:
             raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
-        client = InferenceClient(token=hf_api_token, timeout=
+        client = InferenceClient(token=hf_api_token, timeout=240)
         def remote_generate(prompt: str) -> str:
             response = client.text_generation(
                 prompt,
                 model=repo_id,
                 temperature=self.temperature,
                 top_p=self.top_p,
-                repetition_penalty=1.1
+                repetition_penalty=1.1,
+                wait_for_model=True,
             )
             return response
         from langchain.llms.base import LLM
@@ -172,20 +175,25 @@ class ElevatedRagChain:
         if not MISTRAL_AVAILABLE:
             raise ImportError("Mistral client library not installed. Install with: pip install mistralai")
         from langchain.llms.base import LLM
+        from typing import Any, Optional, List
         class MistralLLM(LLM):
             temperature: float = 0.7
             top_p: float = 0.95
-            _client: Any = PrivateAttr() #
-
-
+            _client: Any = PrivateAttr(default=None)  # Set default to None
+
+            def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
+                # Do not pass api_key to super().__init__ since it's not a field
+                super().__init__(**kwargs)
                 self._client = Mistral(api_key=api_key)
                 self.temperature = temperature
                 self.top_p = top_p
+
             @property
             def _llm_type(self) -> str:
                 return "mistral_llm"
+
             def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-                response = self._client.chat.complete(
+                response = self._client.chat.complete(
                     model="mistral-small-latest",
                     messages=[{"role": "user", "content": prompt}],
                     temperature=self.temperature,
@@ -193,9 +201,11 @@ class ElevatedRagChain:
                     max_tokens=32000
                 )
                 return response.choices[0].message.content
+
             @property
             def _identifying_params(self) -> dict:
                 return {"model": "mistral-small-latest"}
+
         mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
         debug_print("Mistral API pipeline created successfully.")
         return mistral_llm
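The second hunk lengthens the client timeout to 240 seconds and adds repetition_penalty and wait_for_model to the remote call. Below is a minimal, self-contained sketch of that inference path, assuming huggingface_hub's InferenceClient; the model id is a placeholder for the repo_id used in advanced_rag.py, and wait_for_model is omitted because whether text_generation accepts it depends on the huggingface_hub version.

import os
from huggingface_hub import InferenceClient

# Sketch of the remote generation path from the second hunk (assumptions above).
client = InferenceClient(token=os.environ["HF_API_TOKEN"], timeout=240)

def remote_generate(prompt: str) -> str:
    # Sampling parameters mirror the values wired through ElevatedRagChain.
    return client.text_generation(
        prompt,
        model="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder repo_id
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,  # discourages the model from looping on itself
    )

print(remote_generate("Summarize retrieval-augmented generation in one sentence."))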
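The third hunk's fix relies on pydantic private attributes: LangChain's LLM base class is a pydantic model, so the Mistral client has to live in a PrivateAttr assigned after super().__init__() instead of passing through field validation. A self-contained sketch of that pattern, with FakeClient as a hypothetical stand-in for mistralai.Mistral:

from typing import Any
from pydantic import BaseModel, PrivateAttr

class FakeClient:
    # Hypothetical stand-in for mistralai.Mistral.
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

class Wrapper(BaseModel):
    temperature: float = 0.7
    top_p: float = 0.95
    _client: Any = PrivateAttr(default=None)  # private attr, never validated as a field

    def __init__(self, api_key: str, **kwargs: Any) -> None:
        # api_key is not a pydantic field, so it must not reach super().__init__
        super().__init__(**kwargs)
        self._client = FakeClient(api_key)  # assigned after validation, as MistralLLM does

w = Wrapper(api_key="sk-test", temperature=0.2)
print(w.temperature, w._client.api_key)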
|