Spaces:
Build error
Build error
# --- Model loading -----------------------------------------------------------
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub IDs for the base causal LM and the PEFT adapter applied on top.
base_model_name = "facebook/opt-350m"
adapter_model_name = 'msy127/opt-350m-aihubqa-130-dpo-adapter'

# Load the base model, wrap it with the adapter weights, and move the combined
# model to `device` in one expression.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(base_model_name),
    adapter_model_name,
).to(device)

# The tokenizer is read from the adapter repository (presumably it ships the
# tokenizer files alongside the adapter — confirm on the Hub).
tokenizer = AutoTokenizer.from_pretrained(adapter_model_name)
# Chat handler (history). Original note, roughly translated: "history takes the
# place of the prompt -> DialoGPT needed the input encoded before feeding the
# model, whereas the OpenAI API does not." NOTE: as written, the prompt below
# contains only the latest user input; earlier turns are not sent to the model.


def _extract_answer(generated):
    """Return the answer part of `generated`, cut at the next "### Input:" marker.

    If the model never starts a new "### Input:" section (it reached EOS or the
    token budget first), the whole continuation is used instead of failing.
    """
    stop_idx = generated.find("### Input:")
    if stop_idx != -1:
        generated = generated[:stop_idx]
    return generated.strip()


def _history_to_chat_pairs(history):
    """Pair user/assistant messages into (user, assistant) tuples for gr.Chatbot.

    System messages are skipped so they are never displayed as a chat turn.
    """
    turns = [msg["content"] for msg in history if msg["role"] != "system"]
    return list(zip(turns[0::2], turns[1::2]))


def predict(input, history):
    """Answer `input` with the adapter-wrapped model and record the exchange.

    Args:
        input: The user's message, inserted verbatim between triple backquotes
            in the prompt. (The name shadows the builtin; kept for compatibility.)
        history: List of {"role": ..., "content": ...} dicts held in gr.State.
            Mutated in place: one user and one assistant message are appended.

    Returns:
        A tuple ``(messages, history)``: ``messages`` is a list of
        (user, assistant) string pairs for gr.Chatbot, and ``history`` is the
        same list after the update.
    """
    # General (non-chat) model prompt; only the current input is included.
    prompt = f"An AI tool that looks at the context and question separated by triple backquotes, finds the answer corresponding to the question in the context, and answers clearly.\n### Input: ```{input}```\n ### Output: "
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    # Budget new tokens (not total length): the prompt may carry a long context
    # passage, and `max_length` would count it against the answer's budget.
    outputs = model.generate(**encoded, max_new_tokens=256)

    # Decode only the tokens produced after the prompt. This replaces the old
    # `len(prompt) + len('</s>')` offset arithmetic, which depended on the full
    # decode reproducing the prompt exactly behind a leading "</s>".
    prompt_len = encoded["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response = _extract_answer(generated)

    # Append both turns only after generation succeeded, so a failed call cannot
    # leave an unanswered user message that would misalign the chat pairs.
    history.append({"role": "user", "content": input})
    history.append({"role": "assistant", "content": response})
    return _history_to_chat_pairs(history), history
# --- Gradio interface setup --------------------------------------------------
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="ChatBot")
    # Per-session conversation history, seeded with a system message:
    # "You are a friendly AI chatbot. Please answer inputs briefly, concisely
    # and kindly." NOTE(review): predict() builds its prompt from the latest
    # user input only, so this message is not currently sent to the model.
    state = gr.State([
        {"role": "system", "content": "당신은 친절한 인공지능 챗봇입니다. 입력에 대해 짧고 간결하고 친절하게 대답해주세요."}])
    with gr.Row():
        # `container=False` goes to the constructor: the `.style()` method used
        # previously was removed in Gradio 4 and raises AttributeError there.
        txt = gr.Textbox(
            show_label=False,
            placeholder="챗봇에게 아무거나 물어보세요",  # "Ask the chatbot anything"
            container=False,
        )
    # On Enter, predict(text, history) returns (chat pairs, updated history).
    txt.submit(predict, [txt, state], [chatbot, state])

demo.launch()