Spaces:
Runtime error
Runtime error
MS-YUN
commited on
Commit
ยท
9b7601c
1
Parent(s):
b1895c1
Add application file
Browse files
app.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model loading
|
| 2 |
+
import torch
|
| 3 |
+
from peft import PeftConfig, PeftModel
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 5 |
+
|
| 6 |
+
# Hub identifiers: the base causal-LM checkpoint and the adapter applied on top of it.
base_model_name = "facebook/opt-350m"
adapter_model_name = 'msy127/opt-350m-aihubqa-130-dpo-adapter'

# Use the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the base weights, wrap them with the adapter, then move the result to `device`.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(base_model_name),
    adapter_model_name,
)
model = model.to(device)

# The tokenizer is fetched from the adapter repository, not from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)
|
| 14 |
+
|
| 15 |
+
# Conversation-accumulating function (history): the history goes where the prompt would be. DialoGPT encoded the input before feeding it to the model, whereas OpenAI does not encode.
|
| 16 |
+
|
| 17 |
+
def predict(input, history):
    """Generate a reply to one user message and return the updated transcript.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        input: The text the user submitted.
        history: List of ``{"role": ..., "content": ...}`` dicts holding the
            conversation so far. It is mutated in place: the user message and
            the generated reply are appended to it.

    Returns:
        A ``(messages, history)`` tuple. ``messages`` is a list of
        ``(user_text, assistant_text)`` pairs built from ``history``;
        ``history`` is the same list object that was passed in.
    """
    history.append({"role": "user", "content": input})

    # General model: wrap the user text in the instruction template.
    prompt = f"An AI tool that looks at the context and question separated by triple backquotes, finds the answer corresponding to the question in the context, and answers clearly.\n### Input: ```{input}```\n ### Output: "
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # NOTE(review): max_length counts the prompt tokens as well, so a long
    # input leaves little or no room for the answer - consider max_new_tokens.
    outputs = model.generate(input_ids=inputs, max_length=256)

    # Decode only the tokens produced after the prompt. Slicing the decoded
    # string by len(prompt) is fragile (special tokens, detokenisation
    # differences) and is confused by a "### Input:" typed by the user.
    new_tokens = outputs[0][inputs.shape[-1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Keep the text up to the point where the model starts a new
    # "### Input:" section; when it never does, keep everything instead of
    # leaving the reply undefined.
    stop_idx = generated_text.find("### Input:")
    if stop_idx != -1:
        generated_text = generated_text[:stop_idx]
    response = generated_text.strip()

    # Accumulate the reply in the running history.
    history.append({"role": "assistant", "content": response})

    # Pair each user message with the reply that follows it. System entries
    # are skipped so that a leading system prompt cannot shift the pairing.
    turns = [message for message in history if message["role"] != "system"]
    messages = [
        (turns[i]["content"], turns[i + 1]["content"])
        for i in range(0, len(turns) - 1, 2)
    ]

    return messages, history
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Gradio interface setup
|
| 42 |
+
import gradio as gr

# Chat UI: a chatbot pane, a conversation-history state and a single text box
# whose submit event calls predict().
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="ChatBot")

    # Conversation log handed to predict(), seeded with a system instruction.
    # English: "You are a friendly AI chatbot. Please answer the input
    # briefly, concisely and kindly."
    state = gr.State([
        {"role": "system", "content": "당신은 친절한 인공지능 챗봇입니다. 입력에 대해 짧고 간결하고 친절하게 대답해주세요."},
    ])

    with gr.Row():
        # `container` is passed to the constructor: the former
        # `.style(container=False)` call no longer exists on Gradio
        # components and raises AttributeError at start-up.
        # Placeholder in English: "Ask the chatbot anything".
        txt = gr.Textbox(
            show_label=False,
            placeholder="챗봇에게 아무거나 물어보세요",
            container=False,
        )

    # On submit, send the text and the history to predict(); its two return
    # values update the chatbot pane and the stored history.
    txt.submit(predict, [txt, state], [chatbot, state])

demo.launch()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# from PIL import Image
|
| 60 |
+
# import gradio as gr
|
| 61 |
+
# interface = gr.Interface(
|
| 62 |
+
# fn=classify_image,
|
| 63 |
+
# inputs=gr.components.Image(type="pil", label="Upload an Image"),
|
| 64 |
+
# outputs="text",
|
| 65 |
+
# live=True
|
| 66 |
+
# )
|
| 67 |
+
# interface.launch()
|