drumwell committed (verified)
Commit af78cd0 · Parent: d468114

Update app.py

Files changed (1): app.py (+8 -13)
app.py CHANGED
@@ -1,17 +1,21 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from peft import PeftModel
 from threading import Thread
 import torch
 
-# Load your model once at startup
-print("Loading model...")
-tokenizer = AutoTokenizer.from_pretrained("drumwell/autotrain-2duhi-5mmyz")
+# Load base model + your adapter
+print("Loading base model...")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
 model = AutoModelForCausalLM.from_pretrained(
-    "drumwell/autotrain-2duhi-5mmyz",
+    "meta-llama/Llama-3.1-8B-Instruct",
     device_map="auto",
     torch_dtype=torch.float16,
     load_in_8bit=True,
 )
+
+print("Loading your fine-tuned adapter...")
+model = PeftModel.from_pretrained(model, "drumwell/autotrain-2duhi-5mmyz")
 print("Model loaded!")
 
 def respond(
@@ -23,18 +27,13 @@ def respond(
     top_p,
     hf_token: gr.OAuthToken,
 ):
-    """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-    """
     messages = [{"role": "system", "content": system_message}]
     messages.extend(history)
     messages.append({"role": "user", "content": message})
 
-    # Apply chat template
     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(text, return_tensors="pt").to(model.device)
 
-    # Setup streaming
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(
         inputs,
@@ -46,7 +45,6 @@ def respond(
         repetition_penalty=1.1,
     )
 
-    # Generate in a separate thread for streaming
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
@@ -55,9 +53,6 @@
         response += token
         yield response
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 chatbot = gr.ChatInterface(
     respond,
     type="messages",
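In short: the Space previously tried to load drumwell/autotrain-2duhi-5mmyz as a complete model, but per the new comments the AutoTrain repo holds only a fine-tuned PEFT adapter, so the app now loads the meta-llama/Llama-3.1-8B-Instruct base in 8-bit and attaches the adapter with PeftModel.from_pretrained.

A minimal sketch of an equivalent one-step load, assuming the adapter repo's adapter_config.json names Llama-3.1-8B-Instruct as its base model (PEFT then resolves and loads the base itself); this is an alternative, not the commit's code:

# Equivalent one-step load via PEFT's auto class (sketch, see assumptions above)
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained(
    "drumwell/autotrain-2duhi-5mmyz",  # adapter repo; base model is read from its config
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_8bit=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

The commit's explicit two-step version has the advantage that the base repo ID is pinned in app.py rather than read from the adapter's config.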
 
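The unchanged core of respond() is the standard Transformers streaming pattern: model.generate blocks until it finishes, so it runs on a worker thread while the request thread iterates a TextIteratorStreamer and yields partial replies for Gradio to render. A self-contained sketch of that pattern (stream_reply and its arguments are illustrative names, not taken from app.py):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it on a background thread; this generator
    # then consumes decoded text chunks from the streamer as they arrive
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    Thread(target=model.generate, kwargs=kwargs).start()
    text = ""
    for chunk in streamer:  # iteration stops when generation completes
        text += chunk
        yield text  # the UI re-renders the growing reply on each yield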