nosadaniel committed on
Commit
57c3533
·
verified ·
1 Parent(s): 1957184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -42
app.py CHANGED
@@ -1,6 +1,24 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- #from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def respond(
6
  message,
@@ -9,43 +27,45 @@ def respond(
9
  max_tokens,
10
  temperature,
11
  top_p,
12
- hf_token: gr.OAuthToken,
13
  ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- #client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
- client = InferenceClient(token=hf_token.token, model="nosadaniel/llama3-1-8b-tuned")
19
- #model="openai/gpt-oss-20b"
20
- #model="nosadaniel/llama3-1-8b-tuned"
21
-
22
- messages = [{"role": "system", "content": system_message}]
23
-
24
- messages.extend(history)
25
-
26
- messages.append({"role": "user", "content": message})
27
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  response = ""
29
-
30
- for message in client.text_generation(
31
- messages,
32
- max_new_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- choices = message.choices
38
- token = ""
39
- if len(choices) and choices[0].delta.content:
40
- token = choices[0].delta.content
41
-
42
- response += token
43
  yield response
44
 
45
-
46
- """
47
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
48
- """
49
  chatbot = gr.ChatInterface(
50
  respond,
51
  type="messages",
@@ -61,13 +81,8 @@ chatbot = gr.ChatInterface(
61
  label="Top-p (nucleus sampling)",
62
  ),
63
  ],
 
64
  )
65
 
66
- with gr.Blocks() as demo:
67
- with gr.Sidebar():
68
- gr.LoginButton()
69
- chatbot.render()
70
-
71
-
72
  if __name__ == "__main__":
73
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
+ import torch
4
+ from threading import Thread
5
+
6
+ # Load model at startup
7
+ MODEL_NAME = "nosadaniel/llama3-1-8b-tuned"
8
+ print("Loading model and tokenizer...")
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ MODEL_NAME,
13
+ torch_dtype=torch.float16,
14
+ device_map="auto",
15
+ low_cpu_mem_usage=True
16
+ )
17
+
18
+ if tokenizer.pad_token is None:
19
+ tokenizer.pad_token = tokenizer.eos_token
20
+
21
+ print("Model loaded successfully!")
22
 
23
  def respond(
24
  message,
 
27
  max_tokens,
28
  temperature,
29
  top_p,
 
30
  ):
31
+ # Build conversation
32
+ conversation = f"{system_message}\n\n"
33
+
34
+ for msg in history:
35
+ if msg["role"] == "user":
36
+ conversation += f"User: {msg['content']}\n"
37
+ elif msg["role"] == "assistant":
38
+ conversation += f"Assistant: {msg['content']}\n"
39
+
40
+ conversation += f"User: {message}\nAssistant:"
41
+
42
+ # Tokenize
43
+ inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
44
+
45
+ # Setup streamer
46
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
47
+
48
+ generation_kwargs = {
49
+ **inputs,
50
+ "max_new_tokens": max_tokens,
51
+ "temperature": temperature,
52
+ "top_p": top_p,
53
+ "do_sample": True,
54
+ "pad_token_id": tokenizer.pad_token_id,
55
+ "eos_token_id": tokenizer.eos_token_id,
56
+ "streamer": streamer,
57
+ }
58
+
59
+ # Generate in separate thread
60
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
61
+ thread.start()
62
+
63
+ # Stream output
64
  response = ""
65
+ for new_text in streamer:
66
+ response += new_text
 
 
 
 
 
 
 
 
 
 
 
 
67
  yield response
68
 
 
 
 
 
69
  chatbot = gr.ChatInterface(
70
  respond,
71
  type="messages",
 
81
  label="Top-p (nucleus sampling)",
82
  ),
83
  ],
84
+ title="Llama 3.1 8B Tuned Chat"
85
  )
86
 
 
 
 
 
 
 
87
  if __name__ == "__main__":
88
+ demo.launch()