Update app.py
app.py CHANGED
@@ -29,7 +29,7 @@ if not torch.cuda.is_available():
 if torch.cuda.is_available():
     model_id = "mistral-community/Mixtral-8x22B-v0.1-4bit"
     model = AutoModelForCausalLM.from_pretrained(model_id,quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True),
-                                                 device_map="
+                                                 device_map="cuda",
                                                  # torch_dtype=torch.float16,
                                                  # load_in_8bit=True,
                                                  trust_remote_code=True)

@@ -55,7 +55,7 @@ def generate(
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
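For context: the removed line in the first hunk ended in an unterminated string literal (device_map="), a Python syntax error that would crash the Space at import time and matches the reported runtime error; the commit completes it as device_map="cuda". A minimal sketch of the fixed loading block, assuming the transformers/bitsandbytes imports that app.py presumably already has:

# Minimal sketch of the fixed loading block (not the full app.py);
# the imports are assumed to match what the Space already uses.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

if torch.cuda.is_available():
    model_id = "mistral-community/Mixtral-8x22B-v0.1-4bit"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # Let modules that a device map places on CPU stay in fp32.
        quantization_config=BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True),
        device_map="cuda",  # was the unterminated `device_map="` before this commit
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

Note that device_map="cuda" pins every module to the GPU, so llm_int8_enable_fp32_cpu_offload only takes effect with a device map that actually offloads some modules to CPU (e.g. device_map="auto").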
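The second hunk sits inside generate(), where the chat history is flattened into model inputs and left-truncated to the newest MAX_INPUT_TOKEN_LENGTH tokens. A hedged sketch of that step as a standalone helper (the function name and the 4096 limit are illustrative, not taken from this diff):

# Illustrative reconstruction of the prompt-building step in generate();
# build_input_ids and the value of MAX_INPUT_TOKEN_LENGTH are assumptions.
import gradio as gr

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed; defined elsewhere in app.py

def build_input_ids(message, chat_history, tokenizer):
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user},
                             {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Drop tokens from the left so the most recent turns survive.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids

Left-truncation keeps the latest user message intact at the cost of possibly cutting an earlier turn (or the chat-template header) mid-sequence, which is the usual trade-off for this pattern.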