add different llms
- app.py +7 -7
- backend/query_llm.py +13 -5
app.py
CHANGED
@@ -63,25 +63,25 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encode
     prompt_html = template_html.render(documents=documents, query=query)
 
     if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
-
+        generate_fn = generate_hf
     if llm_model == "mistralai/Mistral-7B-v0.1":
-
+        generate_fn = generate_hf
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-
+        generate_fn = generate_hf
     if llm_model == "gpt-3.5-turbo":
-
+        generate_fn = generate_openai
     if llm_model == "gpt-4-turbo-preview":
-
+        generate_fn = generate_openai
 
     #if api_kind == "HuggingFace":
     #    generate_fn = generate_hf
     #elif api_kind == "OpenAI":
     #    generate_fn = generate_openai
     #else:
-
+        raise gr.Error(f"API {api_kind} is not supported")
 
     history[-1][1] = ""
-    for character in generate_fn(prompt, history[:-1]):
+    for character in generate_fn(prompt, history[:-1], llm_model):
         history[-1][1] = character
         yield history, prompt_html
 
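The new branches in bot() select generate_fn from the exact llm_model string and pass the model name through to the generator. As an illustration only, not part of this commit, the same dispatch can be written as a lookup table; the sketch below assumes the generate_hf/generate_openai signatures introduced in backend/query_llm.py, and pick_generator is a hypothetical helper.

# Illustrative sketch, not the Space's code: table-driven version of the
# if-chain above. generate_hf and generate_openai are the functions this
# commit reworks in backend/query_llm.py; pick_generator is hypothetical.
import gradio as gr

from backend.query_llm import generate_hf, generate_openai

GENERATORS = {
    "mistralai/Mistral-7B-Instruct-v0.2": generate_hf,
    "mistralai/Mistral-7B-v0.1": generate_hf,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": generate_hf,
    "gpt-3.5-turbo": generate_openai,
    "gpt-4-turbo-preview": generate_openai,
}

def pick_generator(llm_model: str):
    # Unknown model names surface as a Gradio error, mirroring the raise in bot()
    if llm_model not in GENERATORS:
        raise gr.Error(f"Model {llm_model} is not supported")
    return GENERATORS[llm_model]

A table like this keeps the model list in one place, but the committed code keeps the explicit if-statements shown in the diff.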
backend/query_llm.py
CHANGED
@@ -34,7 +34,7 @@ OAI_GENERATE_KWARGS = {
 }
 
 
-def format_prompt(message: str, api_kind: str):
+def format_prompt(message: str, api_kind: str, tokenizer_name = None):
     """
     Formats the given message using a chat template.
 
@@ -51,12 +51,13 @@ def format_prompt(message: str, api_kind: str):
     if api_kind == "openai":
         return messages
     elif api_kind == "hf":
+        TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name)
         return TOKENIZER.apply_chat_template(messages, tokenize=False)
     elif api_kind:
         raise ValueError("API is not supported")
 
 
-def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_hf(prompt: str, history: str, hf_model_name: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -67,8 +68,14 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         Generator[str, None, str]: A generator yielding chunks of generated text.
         Returns a final string if an error occurs.
     """
+
 
-    formatted_prompt = format_prompt(prompt, "hf")
+    HF_CLIENT = InferenceClient(
+        hf_model_name,
+        token=os.getenv("HF_TOKEN")
+    )
+
+    formatted_prompt = format_prompt(prompt, "hf", hf_model_name)
     formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
 
     try:
@@ -93,7 +100,7 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         raise gr.Error(f"Unhandled Exception: {str(e)}")
 
 
-def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_openai(prompt: str, history: str, model_name: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -108,7 +115,8 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
 
     try:
         stream = OAI_CLIENT.chat.completions.create(
-            model=os.getenv("OPENAI_MODEL"),
+            #model=os.getenv("OPENAI_MODEL"),
+            model = model_name,
            messages=formatted_prompt,
            **OAI_GENERATE_KWARGS,
            stream=True