| | import gradio as gr |
| | import os |
| | from huggingface_hub import InferenceClient |
| |
|
| | |
| | |
| | |
| |
|
| | |
# Model id and default generation settings shared by every request.
COHERE_MODEL = "CohereForAI/c4ai-command-r-plus-08-2024"  # HuggingFace Hub model id
# System prompt sent with every chat request.
# NOTE(review): Korean text looks mojibake (UTF-8 bytes decoded as a Latin
# codepage) — confirm the file's encoding before editing these literals.
SYSTEM_MESSAGE = "๋ธ๋ก๊ทธ ๊ธ์ ์์ฑํด์ฃผ์ธ์."
MAX_TOKENS = 4000    # upper bound on generated tokens per completion
TEMPERATURE = 0.7    # sampling temperature
TOP_P = 0.95         # nucleus-sampling probability mass
| |
|
def get_client(hf_token):
    """Build a HuggingFace InferenceClient bound to COHERE_MODEL.

    Args:
        hf_token: HuggingFace API token used to authenticate.

    Returns:
        An InferenceClient for COHERE_MODEL.

    Raises:
        ValueError: if *hf_token* is empty or None.
    """
    if hf_token:
        return InferenceClient(COHERE_MODEL, token=hf_token)
    raise ValueError("HuggingFace API ํ ํฐ์ด ํ์ํฉ๋๋ค.")
| |
|
def respond_cohere_qna(question, system_message, max_tokens, temperature, top_p, hf_token):
    """Ask the Cohere chat model one question and return its reply text.

    Args:
        question: User prompt text.
        system_message: System prompt placed before the user turn.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        hf_token: HuggingFace API token, forwarded to get_client.

    Returns:
        The content of the first completion choice's message.
    """
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": question},
    ]
    completion = get_client(hf_token).chat_completion(
        messages=conversation,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return completion.choices[0].message.content
| |
|
| | |
def generate_blog(tone, ref1, ref2, ref3):
    """Generate a blog post from three reference texts in the requested tone.

    The HuggingFace token is read directly from the HF_TOKEN environment
    variable; when it is missing, an error message string is returned
    instead of raising.

    Args:
        tone: Desired writing tone (radio-button value from the UI).
        ref1: First reference text.
        ref2: Second reference text.
        ref3: Third reference text.

    Returns:
        The generated blog text, or an error message when HF_TOKEN is unset.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        return "HuggingFace ํ ํฐ์ด ์ค์ ๋์ง ์์์ต๋๋ค."

    user_prompt = (
        f"๋งํฌ: {tone}\n"
        f"์ฐธ์กฐ๋ฌธ 1: {ref1}\n"
        f"์ฐธ์กฐ๋ฌธ 2: {ref2}\n"
        f"์ฐธ์กฐ๋ฌธ 3: {ref3}"
    )

    return respond_cohere_qna(
        question=user_prompt,
        system_message=SYSTEM_MESSAGE,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        hf_token=token,
    )
| |
|
| | |
# --- Gradio UI --------------------------------------------------------------
# Simple form: pick a tone, paste up to three reference texts, press the
# button to generate a blog post. The HF_TOKEN lookup happens inside
# generate_blog, so no token widget is needed here.
with gr.Blocks() as demo:
    gr.Markdown("# ๋ธ๋ก๊ทธ ์์ฑ๊ธฐ")

    with gr.Row():
        tone = gr.Radio(
            choices=["์น๊ทผํ๊ฒ", "์ผ๋ฐ์ ์ธ", "์ ๋ฌธ์ ์ธ"],
            label="๋งํฌ ๋ฐ๊พธ๊ธฐ",
            value="์ผ๋ฐ์ ์ธ",
        )

    ref1 = gr.Textbox(label="์ฐธ์กฐ๋ฌธ 1", lines=3)
    ref2 = gr.Textbox(label="์ฐธ์กฐ๋ฌธ 2", lines=3)
    ref3 = gr.Textbox(label="์ฐธ์กฐ๋ฌธ 3", lines=3)

    # Read-only output box for the generated post.
    answer_output = gr.Textbox(label="์์ฑ๋ ๋ธ๋ก๊ทธ ๊ธ", lines=10, interactive=False)

    submit_button = gr.Button("์์ฑ")

    submit_button.click(
        fn=generate_blog,
        inputs=[tone, ref1, ref2, ref3],
        outputs=answer_output,
    )


if __name__ == "__main__":
    demo.launch()
| |
|