Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
|
|
| 1 |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
import random
|
| 5 |
from textwrap import wrap
|
| 6 |
-
import spaces
|
| 7 |
|
| 8 |
def wrap_text(text, width=90):
|
| 9 |
lines = text.split('\n')
|
|
@@ -51,7 +51,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id = model_id, trust_remote_code
|
|
| 51 |
# Specify the configuration class for the model
|
| 52 |
#model_config = AutoConfig.from_pretrained(base_model_id)
|
| 53 |
|
| 54 |
-
model =
|
| 55 |
|
| 56 |
class ChatBot:
|
| 57 |
def __init__(self):
|
|
@@ -64,7 +64,7 @@ class ChatBot:
|
|
| 64 |
|
| 65 |
def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
|
| 66 |
# Combine the user's input with the system prompt
|
| 67 |
-
formatted_input = f"<s>[INST]{
|
| 68 |
|
| 69 |
# Encode the formatted input using the tokenizer
|
| 70 |
user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
|
| 3 |
import torch
|
| 4 |
import gradio as gr
|
| 5 |
import random
|
| 6 |
from textwrap import wrap
|
|
|
|
| 7 |
|
| 8 |
def wrap_text(text, width=90):
|
| 9 |
lines = text.split('\n')
|
|
|
|
| 51 |
# Specify the configuration class for the model
|
| 52 |
#model_config = AutoConfig.from_pretrained(base_model_id)
|
| 53 |
|
| 54 |
+
model = AutoModelForCausalLM.from_pretrained(model_id , torch_dtype=torch.float16 , device_map= "auto" )
|
| 55 |
|
| 56 |
class ChatBot:
|
| 57 |
def __init__(self):
|
|
|
|
| 64 |
|
| 65 |
def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
|
| 66 |
# Combine the user's input with the system prompt
|
| 67 |
+
formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
|
| 68 |
|
| 69 |
# Encode the formatted input using the tokenizer
|
| 70 |
user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
|