donjun committed on
Commit
80986a7
·
verified ·
1 Parent(s): bf0b389

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -86
app.py CHANGED
@@ -2,94 +2,29 @@ import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
4
 
5
- # 1. 모델 초기화
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
- print(f"Using device: {device}")
 
8
 
9
- try:
10
- model = AutoModelForCausalLM.from_pretrained(
11
- "naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B",
12
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
13
- ).to(device)
14
-
15
- tokenizer = AutoTokenizer.from_pretrained(
16
- "naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B"
17
- )
18
- print("Model loaded successfully")
19
- except Exception as e:
20
- print(f"Model loading failed: {e}")
21
- raise
22
-
23
- # 2. 채팅 생성 함수
24
- def generate_response(chat_history, user_input):
25
- # 대화 기록 업데이트
26
- chat_history.append({"role": "user", "content": user_input})
27
-
28
- # 템플릿 적용
29
- inputs = tokenizer.apply_chat_template(
30
- chat_history,
31
- add_generation_prompt=True,
32
- return_tensors="pt"
33
- ).to(device)
34
-
35
- # 응답 생성
36
- output_ids = model.generate(
37
- inputs,
38
- max_length=1024,
39
- temperature=0.7,
40
- top_p=0.9,
41
- do_sample=True,
42
- eos_token_id=tokenizer.eos_token_id
43
- )
44
-
45
- # 응답 디코딩
46
- response = tokenizer.decode(
47
- output_ids[0][inputs.shape[1]:],
48
- skip_special_tokens=True
49
- )
50
-
51
- chat_history.append({"role": "assistant", "content": response})
52
- return response
53
 
54
- # 3. Gradio 인터페이스
55
- def chat_interface(user_input, chat_history_ui):
56
- # 초기 시스템 메시지
57
- if not chat_history_ui:
58
- chat_history = [
59
- {"role": "system", "content": "당신은 네이버의 CLOVA X AI입니다."}
60
- ]
61
- else:
62
- chat_history = []
63
- for msg in chat_history_ui:
64
- chat_history.extend([
65
- {"role": "user", "content": msg[0]},
66
- {"role": "assistant", "content": msg[1]}
67
- ])
68
-
69
- # 응답 생성
70
- bot_response = generate_response(chat_history, user_input)
71
- chat_history_ui.append((user_input, bot_response))
72
-
73
- return "", chat_history_ui
74
-
75
- # 4. 앱 실행
76
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
77
- gr.Markdown("## 🧑💻 HyperCLOVAX-SEED 챗봇")
78
-
79
- chatbot = gr.Chatbot(height=500)
80
- msg = gr.Textbox(label="메시지 입력")
81
- clear = gr.Button("초기화")
82
-
83
- msg.submit(
84
- chat_interface,
85
- [msg, chatbot],
86
- [msg, chatbot]
87
- )
88
- clear.click(lambda: None, None, chatbot, queue=False)
89
 
90
  if __name__ == "__main__":
91
- demo.launch(
92
- server_name="0.0.0.0",
93
- server_port=7860,
94
- share=False
95
- )
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
4
 
5
+ # 모델 초기화 (간소화 버전)
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ model = AutoModelForCausalLM.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B").to(device)
8
+ tokenizer = AutoTokenizer.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B")
9
 
10
+ def respond(message, history):
11
+ # 대화 기록을 모델 입력 형식으로 변환
12
+ chat = [
13
+ {"role": "system", "content": "당신은 네이버의 CLOVA X AI입니다."},
14
+ *[{"role": "user" if h[0] == message else "assistant", "content": h[1]} for h in history],
15
+ {"role": "user", "content": message}
16
+ ]
17
+
18
+ inputs = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
19
+ outputs = model.generate(inputs, max_length=1024, temperature=0.7)
20
+ return tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Gradio 인터페이스 (최소화)
23
+ demo = gr.ChatInterface(
24
+ respond,
25
+ title="CLOVA X 챗봇",
26
+ description="네이버 HyperCLOVAX-SEED 기반 챗봇"
27
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  if __name__ == "__main__":
30
+ demo.launch(server_name="0.0.0.0", server_port=7860)