Spaces:
Sleeping
Sleeping
Commit ·
2102b2f
1
Parent(s): ce2f7c2
fix
Browse files
app.py
CHANGED
|
@@ -15,16 +15,9 @@ from transformers import (
|
|
| 15 |
import torch.nn.functional as F
|
| 16 |
from collections import defaultdict
|
| 17 |
HF_TOKEN = os.getenv("HF_token")
|
| 18 |
-
CHAT_MODEL_ID = "
|
| 19 |
EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
|
| 20 |
MAX_PROMPT_TOKENS = 8192
|
| 21 |
-
import transformers
|
| 22 |
-
pipeline = transformers.pipeline(
|
| 23 |
-
"text-generation",
|
| 24 |
-
model=CHAT_MODEL_ID,
|
| 25 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
| 26 |
-
device_map="auto",
|
| 27 |
-
)
|
| 28 |
|
| 29 |
# ---------- new defaults & helper ------------------
|
| 30 |
DEFAULT_TEMP = 0.7
|
|
@@ -125,11 +118,16 @@ def build_llm_prompt(system: str, context: list[str], user_question: str) -> str
|
|
| 125 |
# 套用 LLaMA-style prompt 格式
|
| 126 |
input_token = tokenizer.apply_chat_template(
|
| 127 |
conversation,
|
| 128 |
-
|
| 129 |
-
|
| 130 |
)
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
# ---------- 4. Gradio playground (same UI as before) --------------------------
|
| 135 |
def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP):
|
|
@@ -173,25 +171,13 @@ def answer(system: str, context: str, question: str,
|
|
| 173 |
context_list += store["texts"]
|
| 174 |
|
| 175 |
# 2. Build a Qwen-chat prompt (helper defined earlier)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
# 3. Tokenise & cap
|
| 179 |
-
load_chat()
|
| 180 |
-
tokens = tokenizer(
|
| 181 |
-
prompt,
|
| 182 |
-
return_tensors="pt",
|
| 183 |
-
add_special_tokens=False, # we built the chat template ourselves
|
| 184 |
-
)
|
| 185 |
-
|
| 186 |
-
if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
|
| 187 |
-
tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
|
| 188 |
-
|
| 189 |
-
tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
|
| 190 |
|
| 191 |
# --- generate ------------------------------------------------------
|
| 192 |
output = chat_model.generate(
|
| 193 |
-
|
| 194 |
max_new_tokens=512,
|
|
|
|
| 195 |
max_length=MAX_PROMPT_TOKENS + 512,
|
| 196 |
do_sample=True,
|
| 197 |
temperature=temperature,
|
|
|
|
| 15 |
import torch.nn.functional as F
|
| 16 |
from collections import defaultdict
|
| 17 |
HF_TOKEN = os.getenv("HF_token")
|
| 18 |
+
CHAT_MODEL_ID = "NousResearch/Meta-Llama-3-8B-Instruct"
|
| 19 |
EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
|
| 20 |
MAX_PROMPT_TOKENS = 8192
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# ---------- new defaults & helper ------------------
|
| 23 |
DEFAULT_TEMP = 0.7
|
|
|
|
| 118 |
# 套用 LLaMA-style prompt 格式
|
| 119 |
input_token = tokenizer.apply_chat_template(
|
| 120 |
conversation,
|
| 121 |
+
add_generation_prompt=True,
|
| 122 |
+
return_tensors="pt"
|
| 123 |
)
|
| 124 |
|
| 125 |
+
terminators = [
|
| 126 |
+
tokenizer.eos_token_id,
|
| 127 |
+
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
return input_token, terminators
|
| 131 |
|
| 132 |
# ---------- 4. Gradio playground (same UI as before) --------------------------
|
| 133 |
def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP):
|
|
|
|
| 171 |
context_list += store["texts"]
|
| 172 |
|
| 173 |
# 2. Build a Qwen-chat prompt (helper defined earlier)
|
| 174 |
+
input_ids, terminators = build_llm_prompt(system, context_list, question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
# --- generate ------------------------------------------------------
|
| 177 |
output = chat_model.generate(
|
| 178 |
+
input_ids,
|
| 179 |
max_new_tokens=512,
|
| 180 |
+
eos_token_id=terminators,
|
| 181 |
max_length=MAX_PROMPT_TOKENS + 512,
|
| 182 |
do_sample=True,
|
| 183 |
temperature=temperature,
|