Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,8 @@ device = "cuda:0"
|
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
|
| 20 |
config = PeftConfig.from_pretrained(peft_model_id)
|
| 21 |
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
|
| 22 |
-
device_map={"": "cuda:0"},
|
|
|
|
| 23 |
|
| 24 |
uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
|
| 25 |
print(f"PAQUETE DE TRANSFORMERS: {uses_transformers_4_46}")
|
|
@@ -49,7 +50,7 @@ def generate_response(msg: str, history: list[list[str, str]], system_prompt: st
|
|
| 49 |
chat_history = format_history(msg, history, system_prompt)
|
| 50 |
encodeds = tokenizer.apply_chat_template(chat_history, return_tensors="pt", add_generation_prompt=True)
|
| 51 |
model_inputs = encodeds.to("cuda")
|
| 52 |
-
generated_ids = model.generate(model_inputs, repetition_penalty=rep_pen, max_new_tokens=
|
| 53 |
response = tokenizer.batch_decode(generated_ids,skip_special_tokens=True)[0]
|
| 54 |
if len(response)>0:
|
| 55 |
message=response[response.rfind("assistant\n") + len("assistant\n"):]
|
|
|
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
|
| 20 |
config = PeftConfig.from_pretrained(peft_model_id)
|
| 21 |
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
|
| 22 |
+
device_map={"": "cuda:0"},
|
| 23 |
+
quantization_config=bnb_config) #offload_state_dict=False
|
| 24 |
|
| 25 |
uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
|
| 26 |
print(f"PAQUETE DE TRANSFORMERS: {uses_transformers_4_46}")
|
|
|
|
| 50 |
chat_history = format_history(msg, history, system_prompt)
|
| 51 |
encodeds = tokenizer.apply_chat_template(chat_history, return_tensors="pt", add_generation_prompt=True)
|
| 52 |
model_inputs = encodeds.to("cuda")
|
| 53 |
+
generated_ids = model.generate(model_inputs, repetition_penalty=rep_pen, max_new_tokens=248, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, eos_token_id=tokenizer.eos_token_id)
|
| 54 |
response = tokenizer.batch_decode(generated_ids,skip_special_tokens=True)[0]
|
| 55 |
if len(response)>0:
|
| 56 |
message=response[response.rfind("assistant\n") + len("assistant\n"):]
|