Jia0603 commited on
Commit
fb0bc70
·
verified ·
1 Parent(s): 07d75ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

import gradio as gr

# Model repo to serve. Set HF_TOKEN in the environment if the repo is private.
MODEL_ID = "LMSeed/GPT2-small-distilled-100M"
# Bug fix: HF_TOKEN was commented out but still referenced below, causing a
# NameError at import. Read it from the environment; None is fine for public models.
HF_TOKEN = os.environ.get("HF_TOKEN")

# pipeline() device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# `token=` replaces the deprecated `use_auth_token=` keyword.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
if torch.cuda.is_available():
    model = model.to("cuda")

# Text-generation pipeline used by the chat handler below.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,  # reuse the value computed above instead of re-checking CUDA
)
22
def chat_with_model(user_message, chat_history, max_new_tokens=60, temperature=0.8, top_p=0.9):
    """Generate one assistant reply for *user_message* given the running history.

    chat_history is a flat list of alternating strings (user, assistant, user, ...).
    Returns ("", updated_history) so the caller can clear the input textbox.
    """
    roles = ("User", "Assistant")
    # Rebuild the prompt transcript: even indices are user turns, odd are assistant.
    transcript = "".join(
        f"{roles[idx % 2]}: {turn}\n"
        for idx, turn in enumerate(chat_history or [])
    )
    prompt = transcript + f"User: {user_message}\nAssistant: "

    generated = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        num_return_sequences=1,
    )
    # The pipeline echoes the prompt; keep only the newly generated continuation.
    completion = generated[0]["generated_text"][len(prompt):].strip()
    # Truncate at the first newline so the model can't speak for the next turn.
    if "\n" in completion:
        completion = completion.split("\n")[0].strip()

    chat_history = chat_history or []
    chat_history.extend([user_message, completion])
    return "", chat_history
41
# Gradio UI: a chat column with generation-parameter sliders, plus a sidebar.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with GPT-2")
    with gr.Row():
        with gr.Column(scale=3):
            chat = gr.Chatbot(elem_id="chatbot", label="Conversation")
            msg = gr.Textbox(label="Your message")
            send = gr.Button("Send")
            max_tokens = gr.Slider(10, 256, value=60, label="max_new_tokens")
            temp = gr.Slider(0.1, 1.2, value=0.8, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, label="top_p")
        with gr.Column(scale=1):
            gr.Markdown("Model: " + MODEL_ID)

    # Flat alternating history: [user, assistant, user, assistant, ...].
    state = gr.State([])

    def _as_pairs(history):
        # gr.Chatbot renders a list of [user, assistant] pairs; convert the
        # flat alternating state list into that shape for display.
        return [[history[i], history[i + 1]] for i in range(0, len(history) - 1, 2)]

    # Bug fix: the original wired outputs=[msg, state] only, so the Chatbot
    # widget was never updated and the conversation never appeared on screen.
    # Chain a .then() that refreshes the Chatbot from the updated state.
    send.click(fn=chat_with_model,
               inputs=[msg, state, max_tokens, temp, top_p],
               outputs=[msg, state]).then(_as_pairs, inputs=state, outputs=chat)
    # Also allow submitting with the Enter key.
    msg.submit(fn=chat_with_model,
               inputs=[msg, state, max_tokens, temp, top_p],
               outputs=[msg, state]).then(_as_pairs, inputs=state, outputs=chat)

demo.launch()