Lepish committed on
Commit
d5f5706
·
verified ·
1 Parent(s): aefac4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -1,18 +1,24 @@
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
 
5
- # Load model and tokenizer
 
 
 
6
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
7
 
8
- tokenizer = AutoTokenizer.from_pretrained(model_id)
 
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_id,
 
11
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
12
  device_map="auto"
13
  )
14
 
15
- # Helper to build prompt
16
  def build_prompt(user_input, history):
17
  prompt = "You are a pirate chatbot who always responds in pirate speak!\n"
18
  for user_msg, bot_reply in history:
@@ -20,7 +26,7 @@ def build_prompt(user_input, history):
20
  prompt += f"User: {user_input}\nPirate:"
21
  return prompt
22
 
23
- # Chat function
24
  def chat(user_input, history):
25
  prompt = build_prompt(user_input, history)
26
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -29,17 +35,16 @@ def chat(user_input, history):
29
  **inputs,
30
  max_new_tokens=256,
31
  do_sample=True,
32
- top_p=0.9,
33
  temperature=0.8,
 
34
  pad_token_id=tokenizer.eos_token_id
35
  )
36
 
37
- decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
- # Extract only the last bot message
39
- pirate_reply = decoded.split("Pirate:")[-1].strip()
40
  return pirate_reply
41
 
42
- # Gradio UI
43
  with gr.Blocks() as demo:
44
  gr.Markdown("## 🏴‍☠️ Talk to the Pirate Bot!")
45
  chatbot = gr.Chatbot()
 
1
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Hugging Face access token, read from the Space secret literally named "key".
# NOTE(review): "key" is an unusual secret name — confirm it matches the Space
# settings; returns None (anonymous access) if the secret is missing.
HF_TOKEN = os.environ.get("key")

# Gated model: access requires an accepted license and a valid token.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load tokenizer and model with authentication.
# Fix: `use_auth_token=` is deprecated (removed in transformers v5);
# the supported keyword is `token=`.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    # fp16 halves memory on GPU; fall back to fp32 on CPU, where fp16 is slow.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    # Let accelerate place the weights (GPU if available).
    device_map="auto",
)
20
 
21
+ # 🧠 Prompt Builder
22
  def build_prompt(user_input, history):
23
  prompt = "You are a pirate chatbot who always responds in pirate speak!\n"
24
  for user_msg, bot_reply in history:
 
26
  prompt += f"User: {user_input}\nPirate:"
27
  return prompt
28
 
29
+ # 💬 Chat Handler
30
  def chat(user_input, history):
31
  prompt = build_prompt(user_input, history)
32
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
35
  **inputs,
36
  max_new_tokens=256,
37
  do_sample=True,
 
38
  temperature=0.8,
39
+ top_p=0.9,
40
  pad_token_id=tokenizer.eos_token_id
41
  )
42
 
43
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
+ pirate_reply = response.split("Pirate:")[-1].strip()
 
45
  return pirate_reply
46
 
47
+ # 🧱 Gradio UI
48
  with gr.Blocks() as demo:
49
  gr.Markdown("## 🏴‍☠️ Talk to the Pirate Bot!")
50
  chatbot = gr.Chatbot()