File size: 3,489 Bytes
98cc856
2506dd9
98cc856
2506dd9
98cc856
2506dd9
98cc856
 
 
 
2506dd9
98cc856
 
 
 
 
 
 
 
 
 
 
 
 
2506dd9
 
98cc856
 
 
2506dd9
 
 
4342e64
2506dd9
 
98cc856
 
2506dd9
 
98cc856
2506dd9
 
98cc856
2506dd9
98cc856
2506dd9
 
98cc856
2506dd9
98cc856
2506dd9
 
98cc856
 
 
 
2506dd9
 
98cc856
 
2506dd9
98cc856
 
 
 
 
4342e64
 
98cc856
 
 
 
 
2506dd9
98cc856
2506dd9
98cc856
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Import the AutoTokenizer function from the transformers library
from transformers import AutoTokenizer
# Import the pipeline function from the transformers library
from transformers import pipeline
# Import Gradio
import gradio as gr
# Import Transformer
import transformers
# Import pyTorch
import torch

# Hugging Face Hub repo id of the fine-tuned Llama-2 chat model to load.
# NOTE(review): this triggers a model download on first run — requires network.
model = "arcpolar/Ubuntu_Llama_Chat_7B"
# Load the tokenizer that matches the model (needed below for eos_token_id).
tokenizer = AutoTokenizer.from_pretrained(model)

# Llama pipeline learned from Ograbek, K. youtube video and colab note book
# Code from https://colab.research.google.com/drive/1SSv6lzX3Byu50PooYogmiwHqf5PQN68E
# Module-level text-generation pipeline shared by all requests; constructing it
# here (at import time) means model weights are loaded exactly once.
Ubuntu_Llama_Chat_pipeline = pipeline(
    "text-generation",  # Specify the task as text-generation
    model=model, # Use Ubuntu_Llama_Chat_7B for the task
    torch_dtype=torch.float16, # half precision halves memory vs float32
    device_map="auto", # let accelerate place layers on GPU/CPU automatically
)

# Format Message and System Prompt learned from Ograbek, K. youtube video and colab notebook
# Code from https://colab.research.google.com/drive/1SSv6lzX3Byu50PooYogmiwHqf5PQN68E
# System block that opens the first Llama-2 instruction segment; the first
# user turn appended to it continues the "<s>[INST]" it leaves open.
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are a helpful bot. Your answers are clear and concise.
<</SYS>>

"""

def format_message(message: str, history: list, memory_limit: int = 5) -> str:
    """Build a Llama-2 chat prompt from the current message and prior turns.

    Args:
        message: The user's newest message.
        history: List of (user_msg, model_answer) pairs from earlier turns.
        memory_limit: Maximum number of most-recent turns to keep in the prompt.

    Returns:
        The fully formatted prompt string ending in an open "[/INST]" so the
        model's continuation is its reply.
    """
    # Drop older turns so the prompt only carries the most recent context.
    recent = history[-memory_limit:] if len(history) > memory_limit else history

    # No prior turns: the prompt is just the system block plus the new message.
    if not recent:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    # The oldest kept turn completes the segment opened by SYSTEM_PROMPT.
    first_user, first_answer = recent[0]
    parts = [SYSTEM_PROMPT + f"{first_user} [/INST] {first_answer} </s>"]

    # Every later turn gets its own "<s>[INST] ... [/INST] ... </s>" segment.
    parts.extend(
        f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
        for user_msg, model_answer in recent[1:]
    )

    # Open a fresh instruction segment for the message awaiting a reply.
    parts.append(f"<s>[INST] {message} [/INST]")
    return "".join(parts)

# Generate response learned from Ograbek, K. youtube video and colab notebook
# Code from https://colab.research.google.com/drive/1SSv6lzX3Byu50PooYogmiwHqf5PQN68E
def get_response(message: str, history: list) -> str:
    """Generate the chatbot's reply to `message` given the chat `history`.

    Args:
        message: The user's newest message.
        history: List of (user_msg, model_answer) pairs from earlier turns.

    Returns:
        The model's generated reply, with the echoed prompt stripped off.
    """
    # Turn the message plus recent history into a Llama-2 formatted prompt.
    query = format_message(message, history)

    # Run generation; sampling with a small top-k keeps replies varied yet focused.
    sequences = Ubuntu_Llama_Chat_pipeline(
        query,
        do_sample=True,             # sample instead of greedy decoding
        top_k=10,                   # restrict sampling to the 10 most likely tokens
        num_return_sequences=1,     # a single candidate reply is enough for chat
        eos_token_id=tokenizer.eos_token_id,  # stop generating at end-of-sequence
        max_length=1024,            # cap on prompt + generated tokens combined
    )

    # The pipeline echoes the prompt in 'generated_text'; slice it off so only
    # the newly generated reply remains.
    response = sequences[0]['generated_text'][len(query):].strip()

    # Echo the reply to the console for debugging alongside the Gradio UI.
    print("Chatbot:", response)

    return response

# Launch a chat interface using the `get_response` function.
# NOTE: blocks here serving the web UI; Gradio passes (message, history)
# to get_response for every user turn.
gr.ChatInterface(get_response).launch()