Jia0603 committed on
Commit
762089e
·
verified ·
1 Parent(s): 7d37afc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -14
app.py CHANGED
@@ -4,12 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  import gradio as gr
5
 
6
  MODEL_ID = "LMSeed/GPT2-small-distilled-100M"
7
- # HF_TOKEN = os.environ.get("HF_TOKEN") # 如果模型私有
8
 
9
  device = 0 if torch.cuda.is_available() else -1
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
 
13
  if torch.cuda.is_available():
14
  model = model.to("cuda")
15
 
@@ -17,30 +17,49 @@ generator = pipeline(
17
  "text-generation",
18
  model=model,
19
  tokenizer=tokenizer,
20
- device=0 if torch.cuda.is_available() else -1
21
  )
22
 
23
  def chat_with_model(user_message, chat_history, max_new_tokens=60, temperature=0.8, top_p=0.9):
 
 
 
 
 
24
  history_text = ""
25
- if chat_history:
26
- for i, msg in enumerate(chat_history):
27
- role = "User" if i % 2 == 0 else "Assistant"
28
- history_text += f"{role}: {msg}\n"
29
  history_text += f"User: {user_message}\nAssistant: "
30
 
31
- outputs = generator(history_text, max_new_tokens=max_new_tokens,
32
- do_sample=True, temperature=float(temperature),
33
- top_p=float(top_p), num_return_sequences=1)
 
 
 
 
 
 
 
34
  reply = outputs[0]["generated_text"][len(history_text):].strip()
 
 
35
  if "\n" in reply:
36
  reply = reply.split("\n")[0].strip()
37
- chat_history = chat_history or []
 
38
  chat_history.append(user_message)
39
  chat_history.append(reply)
40
- return "", chat_history
 
 
41
 
42
  with gr.Blocks() as demo:
43
- gr.Markdown("# Chat with GPT-2")
 
 
44
  with gr.Row():
45
  with gr.Column(scale=3):
46
  chat = gr.Chatbot(elem_id="chatbot", label="Conversation")
@@ -49,10 +68,16 @@ with gr.Blocks() as demo:
49
  max_tokens = gr.Slider(10, 256, value=60, label="max_new_tokens")
50
  temp = gr.Slider(0.1, 1.2, value=0.8, label="temperature")
51
  top_p = gr.Slider(0.1, 1.0, value=0.9, label="top_p")
 
52
  with gr.Column(scale=1):
53
  gr.Markdown("Model: " + MODEL_ID)
54
 
55
  state = gr.State([])
56
- send.click(fn=chat_with_model, inputs=[msg, state, max_tokens, temp, top_p],
57
- outputs=[msg, state])
 
 
 
 
 
58
  demo.launch()
 
4
  import gradio as gr
5
 
6
  MODEL_ID = "LMSeed/GPT2-small-distilled-100M"
 
7
 
8
  device = 0 if torch.cuda.is_available() else -1
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
11
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
12
+
13
  if torch.cuda.is_available():
14
  model = model.to("cuda")
15
 
 
17
  "text-generation",
18
  model=model,
19
  tokenizer=tokenizer,
20
+ device=device
21
  )
22
 
23
def chat_with_model(user_message, chat_history, max_new_tokens=60, temperature=0.8, top_p=0.9):
    """Generate one assistant turn and append it to the conversation.

    Parameters
    ----------
    user_message : str
        The new user input from the textbox.
    chat_history : list[tuple[str, str]] | None
        Prior (user, assistant) turns; ``None`` is treated as empty.
    max_new_tokens, temperature, top_p
        Sampling controls; they arrive as Gradio slider values, so they
        are coerced to int/float before reaching the pipeline.

    Returns
    -------
    tuple
        ``("", chat_history, chat_history)`` — clears the textbox and feeds
        the updated history to both the Chatbot display and the State.
    """
    if chat_history is None:
        chat_history = []

    # Rebuild the prompt from prior (user, assistant) turns.
    history_text = ""
    for user_turn, assistant_turn in chat_history:
        history_text += f"User: {user_turn}\nAssistant: {assistant_turn}\n"
    history_text += f"User: {user_message}\nAssistant: "

    outputs = generator(
        history_text,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        num_return_sequences=1,
    )

    # The pipeline echoes the prompt; keep only the newly generated text.
    reply = outputs[0]["generated_text"][len(history_text):].strip()
    # Truncate at the first newline so the model cannot fabricate extra turns.
    reply = reply.split("\n")[0].strip()

    # Bug fix: gr.Chatbot renders a list of (user, assistant) PAIRS.  The old
    # code appended two flat strings per turn, which Chatbot cannot display.
    # Changing the state format is safe because the state starts as [] and is
    # only read and written by this function.
    chat_history.append((user_message, reply))

    return "", chat_history, chat_history
57
+
58
 
59
  # Gradio UI (diff "after" side).  NOTE(review): the `msg` textbox and the
  # `send` button referenced below are defined on lines 66-67 of the file,
  # which are elided from this hunk — confirm against the full file.
  with gr.Blocks() as demo:
60
+
61
  # Page title shown above the chat area.
+ gr.Markdown("# Chat with Stu")
62
+
63
  with gr.Row():
64
  with gr.Column(scale=3):
  # Conversation display, refreshed from chat_with_model's output.
  chat = gr.Chatbot(elem_id="chatbot", label="Conversation")
 
68
  # Sampling controls passed straight into chat_with_model.
  max_tokens = gr.Slider(10, 256, value=60, label="max_new_tokens")
69
  temp = gr.Slider(0.1, 1.2, value=0.8, label="temperature")
70
  top_p = gr.Slider(0.1, 1.0, value=0.9, label="top_p")
71
+
72
  with gr.Column(scale=1):
73
  gr.Markdown("Model: " + MODEL_ID)
74
 
75
  # Conversation state shared with chat_with_model.
  # NOTE(review): chat_with_model appends flat alternating strings here, but
  # gr.Chatbot expects (user, assistant) pairs — verify rendering.
  state = gr.State([])
76
+
77
  # Wire the Send button: clears the textbox, updates the Chatbot and state
  # (chat_with_model returns "", history, history).
+ send.click(
78
+ fn=chat_with_model,
79
+ inputs=[msg, state, max_tokens, temp, top_p],
80
+ outputs=[msg, chat, state]
81
+ )
82
+
83
  demo.launch()