ranamhamoud committed on
Commit
92feef3
·
verified ·
1 Parent(s): 8bcbf38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -29,19 +29,26 @@ this demo is governed by the original [license](https://huggingface.co/spaces/hu
29
  if not torch.cuda.is_available():
30
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
31
 
32
- # Model and Tokenizer Configuration
33
- model_id = "meta-llama/Llama-2-7b-hf"
34
- bnb_config = BitsAndBytesConfig(
35
- load_in_4bit=True,
36
- bnb_4bit_use_double_quant=False,
37
- bnb_4bit_quant_type="nf4",
38
- bnb_4bit_compute_dtype=torch.bfloat16
39
- )
40
- base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
41
- model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
42
- tokenizer = AutoTokenizer.from_pretrained(model_id)
43
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
 
44
 
 
45
  # MongoDB Connection
46
  PASSWORD = os.environ.get("MONGO_PASS")
47
  connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
@@ -74,6 +81,12 @@ def generate(
74
  repetition_penalty: float = 1.0,
75
  ) -> Iterator[str]:
76
  conversation = []
 
 
 
 
 
 
77
  for user, assistant in chat_history:
78
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
79
  conversation.append({"role": "user", "content": make_prompt(message)})
@@ -116,6 +129,7 @@ def generate(
116
  chat_interface = gr.ChatInterface(
117
  fn=generate,
118
  stop_btn=None,
 
119
  examples=[
120
  ["Can you explain briefly to me what is the Python programming language?"],
121
  ["Could you please provide an explanation about the concept of recursion?"],
 
29
  if not torch.cuda.is_available():
30
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
31
 
32
+ if torch.cuda.is_available():
33
+ model_id = "meta-llama/Llama-2-7b-hf"
34
+ bnb_config = BitsAndBytesConfig(
35
+ load_in_4bit=True,
36
+ bnb_4bit_use_double_quant=False,
37
+ bnb_4bit_quant_type="nf4",
38
+ bnb_4bit_compute_dtype=torch.bfloat16
39
+ )
40
+ base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
41
+ storytell_model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
42
+ storytell_tokenizer = AutoTokenizer.from_pretrained(model_id)
43
+ storytell_tokenizer.pad_token = tokenizer.eos_token
44
+
45
+
46
+ editing_model_id = "meta-llama/Llama-2-7b-chat-hf"
47
+ editing_model = AutoModelForCausalLM.from_pretrained(editing_model_id, torch_dtype=torch.float16, device_map="auto")
48
+ editing_tokenizer = AutoTokenizer.from_pretrained(model_id)
49
+ editing_tokenizer.use_default_system_prompt = False
50
 
51
+
52
  # MongoDB Connection
53
  PASSWORD = os.environ.get("MONGO_PASS")
54
  connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
 
81
  repetition_penalty: float = 1.0,
82
  ) -> Iterator[str]:
83
  conversation = []
84
+ if model_choice == "Storytell":
85
+ model = storytell_model
86
+ tokenizer = storytell_tokenizer
87
+ else:
88
+ model = editing_model
89
+ tokenizer = editing_tokenizer
90
  for user, assistant in chat_history:
91
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
92
  conversation.append({"role": "user", "content": make_prompt(message)})
 
129
  chat_interface = gr.ChatInterface(
130
  fn=generate,
131
  stop_btn=None,
132
+ additional_inputs=[gr.Dropdown(choices=["Storytell", "HF Meta Llama 7b Chat"], label="Choose Model")],
133
  examples=[
134
  ["Can you explain briefly to me what is the Python programming language?"],
135
  ["Could you please provide an explanation about the concept of recursion?"],