| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import spaces |
| | import torch |
| |
|
| | import gradio as gr |
| |
|
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModelForSeq2SeqLM, |
| | pipeline |
| | ) |
| |
|
| | from huggingface_hub import InferenceClient |
| |
|
# Report which accelerator torch can see. The value is only printed here;
# the translation pipelines in this file are created with device="cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"current device is: {device}")
| |
|
| | |
| |
|
class chat_engine_hf_api:
    """Chat backend that queries a hosted model through the Hugging Face Inference API.

    The client targets ``microsoft/Phi-3.5-mini-instruct`` and authenticates
    with the token read from the ``HF_TOKEN_API`` environment variable.
    """

    def __init__(self):
        # Raises KeyError if HF_TOKEN_API is not set in the environment.
        api_token = os.environ['HF_TOKEN_API']
        self.client = InferenceClient("microsoft/Phi-3.5-mini-instruct", token=api_token)

    def answer(self, message, history):
        """Send ``message`` after the prior ``history`` and return the model's reply.

        Args:
            message: the new user message (French text in this app).
            history: previous turns as ``{"role": ..., "content": ...}`` dicts.
                A new list is built for the request, so ``history`` is not modified.

        Returns:
            The text content of the first choice returned by ``chat_completion``.
        """
        # The assistant instructions are prepended to the user turn itself
        # rather than sent as a separate system message.
        prompt = f"tu es un assistant francophone. Répond en une seule phrase sans formattage.\n{message}"
        conversation = history + [{"role": "user", "content": prompt}]
        completion = self.client.chat_completion(conversation, max_tokens=512, temperature=0.5)
        return completion.choices[0].message.content
| |
|
# Single shared client for the LLM backend.
chat_engine = chat_engine_hf_api()

# M2M100 fine-tunes used for the two translation directions.
fw_modelcard = "amurienne/gallek-m2m100-v0.2"  # French -> Breton
bw_modelcard = "amurienne/kellag-m2m100-v0.2"  # Breton -> French


def _load_translation_stack(modelcard, src_lang, tgt_lang):
    """Load a seq2seq checkpoint and wrap it in a CPU translation pipeline.

    Returns a ``(model, tokenizer, pipeline)`` tuple so the caller can keep
    module-level handles on each piece.
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(modelcard)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(modelcard)
    translator = pipeline(
        "translation",
        model=seq2seq_model,
        tokenizer=seq2seq_tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_length=400,
        device="cpu",
    )
    return seq2seq_model, seq2seq_tokenizer, translator


# Forward direction is loaded first, then backward (same order as before).
fw_model, fw_tokenizer, fw_translation_pipeline = _load_translation_stack(fw_modelcard, 'fr', 'br')
bw_model, bw_tokenizer, bw_translation_pipeline = _load_translation_stack(bw_modelcard, 'br', 'fr')
| |
|
| | |
@spaces.GPU
def translate(text, forward: bool):
    """Translate ``text`` between French and Breton.

    Args:
        text: the sentence to translate.
        forward: truthy for French -> Breton (``fw_translation_pipeline``),
            falsy for Breton -> French (``bw_translation_pipeline``).

    Returns:
        The ``translation_text`` field of the first pipeline result.
    """
    # Both directions prepend a fixed task prefix to the input text.
    if not forward:
        # NOTE(review): this Breton prefix reads "translate from French to Breton"
        # although this branch goes Breton -> French. Kept verbatim in case the
        # kellag model was trained with it; confirm against the model card.
        outputs = bw_translation_pipeline("treiñ eus ar galleg d'ar brezhoneg: " + text)
    else:
        outputs = fw_translation_pipeline("traduis de français en breton: " + text)
    return outputs[0]['translation_text']
| |
|
| | |
# Number of exchanges (one user + one assistant message each) kept when the
# chat histories are trimmed.
max_history_length = 3

# Module-level French-language conversation history; starts empty.
native_chat_history = []

# Breton example prompts shown in the chat UI.
example_queries = [
    {"text": query}
    for query in (
        "Petra eo ar rekipe krampouezh ?",
        "Pelec'h emañ Pariz ?",
        "Petra eo kêr vrasañ Breizh ?",
        "Kont din ur farsadenn bugel ?",
    )
]
| |
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    gr.Markdown("# BreizhBot\n## Breton Chatbot (Translation based)\nPart of the [GweLLM](https://github.com/blackccpie/GweLLM) project")

    chatbot = gr.Chatbot(
        label="Chat",
        placeholder="Degemer mat, petra a c'hellan ober evidoc'h ?",
        examples=example_queries,
        type="messages")
    msg = gr.Textbox(label='User Input')

    # French-side conversation history that is sent to the LLM. It lives in a
    # per-session gr.State instead of a module-level list: a global would be
    # shared by every browser session, leaking one user's context into
    # another's and letting any user's "clear" wipe everyone's history.
    native_history_state = gr.State([])

    def clear():
        """
        Handles clearing chat: resets this session's French (LLM-side) history
        to match the emptied chatbot.
        """
        return []

    chatbot.clear(clear, inputs=None, outputs=[native_history_state])

    def example_input(evt: gr.SelectData):
        """
        Handles example input selection by copying the example text into the textbox.
        """
        return evt.value["text"]

    def user_input(message, chat_history):
        """
        Handles instant display of the user query (without waiting for model answer).
        Blank messages are ignored.
        """
        if not message or not message.strip():
            return chat_history
        chat_history.append({"role": "user", "content": message})
        return chat_history

    def respond(message, chat_history, native_history):
        """
        Handles bot response generation.

        Translates the Breton query to French, asks the LLM using this session's
        French history, translates the answer back to Breton, and returns
        (cleared textbox value, updated Breton history, updated French history).
        """
        # Nothing to answer; user_input did not add a bubble either.
        if not message or not message.strip():
            return "", chat_history, native_history

        fr_message = translate(message, forward=False)
        print(f"user fr -> {fr_message}")

        bot_fr_message = chat_engine.answer(fr_message, native_history)
        print(f"bot fr -> {bot_fr_message}")
        bot_br_message = translate(bot_fr_message, forward=True)
        print(f"bot br -> {bot_br_message}")

        chat_history.append({"role": "assistant", "content": bot_br_message})

        # Build a new list rather than mutating the state value in place.
        native_history = native_history + [
            {"role": "user", "content": fr_message},
            {"role": "assistant", "content": bot_fr_message},
        ]

        # Keep only the last `max_history_length` exchanges (two messages each)
        # in both histories; both grow by one exchange per turn, so they stay aligned.
        max_messages = max_history_length * 2
        return "", chat_history[-max_messages:], native_history[-max_messages:]

    chatbot.example_select(example_input, None, msg).then(
        user_input, [msg, chatbot], chatbot
    ).then(
        respond,
        [msg, chatbot, native_history_state],
        [msg, chatbot, native_history_state],
    )

    msg.submit(user_input, [msg, chatbot], chatbot).then(
        respond,
        [msg, chatbot, native_history_state],
        [msg, chatbot, native_history_state],
    )

if __name__ == "__main__":
    demo.launch()