| |
| |
| |
| from typing import TypedDict |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain_huggingface import HuggingFacePipeline |
| |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from transformers import pipeline |
| |
| import torch |
| |
| import gradio as gr |
| |
| from asyncio import sleep |
| |
| from vector_store import get_document_database |
|
|
|
|
class _ChatMessageRequired(TypedDict):
    """Keys that every chat message carries."""

    role: str      # e.g. "user" or "assistant"
    content: str   # message text shown in the chat


class ChatMessage(_ChatMessageRequired, total=False):
    """One entry of the Gradio ``type="messages"`` chat history.

    ``metadata`` is optional. Assistant messages created by this app set it to
    ``{"title": None}``, while user messages appended by ``update_chat`` have
    no ``metadata`` key at all, so it must not be declared as required.
    """

    metadata: dict
|
|
|
|
| |
| |
# Hugging Face model id of the instruction-tuned LLM that writes the answers.
MODEL_NAME = "google/gemma-2-2b-it"

# Load the causal language model once at import time, with bfloat16
# (16-bit) weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16
)

# Tokenizer for the same model id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Text-generation pipeline shared by every request.
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,         # low temperature -> conservative sampling
    do_sample=True,          # temperature only takes effect when sampling
    repetition_penalty=1.1,  # mildly discourage repeated phrases
    return_full_text=True,   # output is prompt + completion; the answer is
                             # parsed out of it in generate_answer
    max_new_tokens=400,      # upper bound on generated tokens per reply
)

# Wrap the pipeline as a LangChain LLM so it can be used in runnable chains.
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
|
|
| |
# Build the document database once at import time; generate_answer turns it
# into a retriever via db.as_retriever(...).
# NOTE(review): the glob is interpreted by vector_store.get_document_database;
# presumably it selects files three directory levels below learning_material/
# — confirm against that module.
print("creating the document database")
db = get_document_database("learning_material/*/*/*")
print("Document database is ready")
|
|
|
|
def _escape_template_text(content):
    """Return *content* with braces doubled so an f-string prompt template renders it verbatim.

    Non-string content is returned unchanged.
    """
    if isinstance(content, str):
        return content.replace("{", "{{").replace("}", "}}")
    return content


def generate_prompt(message_history: list[ChatMessage], max_history=5):
    """Build the chat prompt template for the next assistant turn.

    The template consists of:

    * the assistant's system instructions,
    * a ``{context}`` system message, filled in by the retrieval chain,
    * two greeting messages (Finnish and English) while the history has
      fewer than 4 messages,
    * the last ``max_history`` messages of ``message_history``, and
    * a trailing ``<RESPONSE>:`` assistant message marking where the model's
      reply begins.

    History contents are brace-escaped, so text such as ``{author}`` typed by
    the user (or produced by the model) is shown literally. Unescaped, it
    would be parsed as a template variable and break prompt formatting.

    Args:
        message_history: Conversation so far, oldest first. Each entry needs
            ``"role"`` and ``"content"`` keys.
        max_history: Maximum number of trailing messages to include. Values
            <= 0 include no history messages.

    Returns:
        A ``ChatPromptTemplate`` whose only template variable is ``context``.
    """
    system_instructions = (
        "You are 'thesizer', a HAMK thesis assistant.\n"
        "You will help the user with technicalities on writing a thesis\n"
        "for hamk. If you can't find the answer from the context given to you,\n"
        "you will tell the user that you cannot assist with the specific topic.\n"
        "You speak both Finnish and English by following the user's language.\n"
        "Continue the conversation with a single response from the AI."
    )

    prompt_template = ChatPromptTemplate([
        ("system", system_instructions),
        ("system", "{context}"),
    ])

    # Early in the conversation, seed the prompt with the bilingual greeting.
    if len(message_history) < 4:
        prompt_template.append(
            ("assistant", "Hei! Kuinka voin auttaa opinnäytetyösi kanssa?"),
        )
        prompt_template.append(
            ("assistant", "Hello! How can I help you with authoring your thesis?"),
        )

    # seq[-0:] is the whole list, so a non-positive limit is special-cased.
    recent_messages = message_history[-max_history:] if max_history > 0 else []
    for message in recent_messages:
        prompt_template.append(
            (message["role"], _escape_template_text(message["content"]))
        )

    # Marker after which the model's reply starts in the rendered prompt.
    prompt_template.append(
        ("assistant", "<RESPONSE>:")
    )

    return prompt_template
|
|
|
|
def _extract_answer(raw_output: str, prompt_text: str, user_input: str) -> str:
    """Return the model's reply from the raw LLM output for *prompt_text*.

    With ``return_full_text=True`` the text-generation pipeline returns the
    prompt followed by the completion, so the reply is everything after the
    prompt prefix. If the output does not start with the prompt (e.g. the
    pipeline was reconfigured), fall back to the legacy heuristic:
    keep the text after the last occurrence of the user's message, then
    after the first ``<RESPONSE>:`` marker.
    """
    if raw_output.startswith(prompt_text):
        return raw_output[len(prompt_text):].strip()

    tail = raw_output
    if user_input:
        # str.split("") raises ValueError, so only split on a non-empty message.
        tail = tail.split(user_input).pop()
    return tail.split("<RESPONSE>:", 1).pop().strip()


async def generate_answer(message_history: list[ChatMessage]):
    """Generate the assistant's reply to the latest message in the history.

    Retrieves the most similar documents from ``db`` for the latest message
    and renders the prompt from ``generate_prompt``. It then runs the LLM via
    LangChain's async entry points, so the Gradio event loop is not blocked
    while the model generates. Finally it strips the echoed prompt from the
    output.

    Args:
        message_history: Conversation so far. The last entry is the user's
            new message.

    Returns:
        The reply text, with blank lines replaced by ``<br>``.
    """
    n_of_best_results = 4
    retriever = db.as_retriever(
        search_type="similarity", search_kwargs={"k": n_of_best_results})

    print("generating prompt")
    prompt = generate_prompt(message_history, max_history=5)
    print("prompt is ready")

    user_input = message_history[-1]["content"]
    print("invoking")
    # Each step is awaited: LangChain's default async implementations run the
    # synchronous retriever/model code in an executor instead of on the loop.
    context_docs = await retriever.ainvoke(user_input)
    prompt_value = await prompt.ainvoke(
        {"context": context_docs, "user_input": user_input})
    # Render the prompt once and send exactly this string to the LLM, so the
    # echoed prompt can be stripped from the output as an exact prefix.
    prompt_text = prompt_value.to_string()
    response = await llm.ainvoke(prompt_text)
    print("response recieved from invoke")

    print("=====raw response=====")
    print(response)

    parsed_answer = _extract_answer(response, prompt_text, str(user_input))
    print(repr(parsed_answer))

    # Presumably so paragraph breaks survive rendering in the chat UI.
    return parsed_answer.replace("\n\n", "<br>")
|
|
|
|
def update_chat(user_message: str, history: list):
    """Clear the input box and append the user's message to the history.

    A new list is returned; the ``history`` passed in is not modified.

    Args:
        user_message: Text the user just submitted.
        history: Current chat history in messages format.

    Returns:
        ``("", updated_history)``. The empty string resets the textbox.
    """
    user_entry = {"role": "user", "content": user_message}
    updated_history = history + [user_entry]
    cleared_textbox = ""
    return cleared_textbox, updated_history
|
|
|
|
async def handle_conversation(
    history: list[ChatMessage],
    characters_per_second=80
):
    """Append the assistant's reply to ``history`` and stream it character by character.

    A new, empty assistant message is appended to ``history`` in place. The
    reply from ``generate_answer`` is then added to it one character at a
    time, and ``history`` is yielded after each character, producing a typing
    effect in the chatbot.

    Args:
        history: Chat history. Mutated in place.
        characters_per_second: Reveal speed. Values <= 0 disable the delay
            instead of raising ``ZeroDivisionError``.

    Yields:
        ``history`` after each added character.
    """
    bot_message = await generate_answer(history)
    new_message: ChatMessage = {"role": "assistant",
                                "metadata": {"title": None},
                                "content": ""}
    history.append(new_message)

    delay = 1 / characters_per_second if characters_per_second > 0 else 0
    for character in bot_message:
        history[-1]['content'] += character
        yield history
        await sleep(delay)
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the thesis assistant.

    The layout is a title and description, a messages-style chatbot, and a
    textbox with a Send button. Clicking Send and submitting the textbox are
    wired identically:
    ``update_chat`` runs first (outside the queue, ``queue=False``), then
    ``handle_conversation`` streams the reply into the chatbot.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as interface:
        gr.Markdown("# 📄 Thesizer: HAMK Thesis Assistant")
        gr.Markdown("Ask for help with authoring the HAMK thesis!")

        gr.Markdown("## Chat interface")

        with gr.Column():
            chatbot = gr.Chatbot(type="messages")

            with gr.Row():
                user_input = gr.Textbox(
                    label="You:",
                    placeholder="Type your message here...",
                    show_label=False
                )
                send_button = gr.Button("Send")

        # Register the same two-step pipeline on both triggers, in the same
        # order as before: button click first, then textbox submit.
        for trigger in (send_button.click, user_input.submit):
            first_step = trigger(
                fn=update_chat,
                inputs=[user_input, chatbot],
                outputs=[user_input, chatbot],
                queue=False,
            )
            first_step.then(
                fn=handle_conversation,
                inputs=chatbot,
                outputs=chatbot,
            )

    return interface
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server (launch() with no arguments).
    app = create_interface()
    app.launch()
|
|