File size: 4,319 Bytes
b722fbb
66043b9
 
b722fbb
538b5f4
66043b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b722fbb
 
66043b9
b722fbb
66043b9
 
 
 
 
 
 
 
 
 
b722fbb
66043b9
b722fbb
66043b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b722fbb
 
66043b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b722fbb
66043b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318bf1e
66043b9
 
 
 
 
b722fbb
66043b9
 
 
b722fbb
66043b9
 
 
 
 
b722fbb
318bf1e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from transformers import pipeline, Conversation
import torch
from huggingface_hub import login

# --- Configuration ---

# Model ids offered in the UI dropdown. The larger models (70B / Mixtral)
# need substantial GPU memory, and the Llama/CodeLlama repos are gated —
# users must supply a Hugging Face token to download them.
MODEL_CHOICES = [
    "mistralai/Mistral-7B-Instruct-v0.2",  # Good balance
    "meta-llama/Llama-2-70b-chat-hf",  # Higher quality, requires HF token, more resources
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # Potentially best quality, high resources
    "codellama/CodeLlama-70b-Instruct-hf"  # Best for code, high resources
]

# Prefer the first CUDA device when available, otherwise run on CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


# --- Helper Functions ---

def load_model(model_name, hf_token=None):
    """Load a conversational pipeline for *model_name*.

    Parameters
    ----------
    model_name : str
        Hugging Face model id (e.g. ``"mistralai/Mistral-7B-Instruct-v0.2"``).
    hf_token : str | None
        Optional Hugging Face access token, needed for gated repos.

    Returns
    -------
    tuple
        ``(pipeline, status_message)`` on success, ``(None, error_message)``
        on failure — callers check the first element for ``None``.
    """
    try:
        if hf_token:
            # Authenticate first so gated repos (e.g. Llama 2) can be pulled.
            login(token=hf_token)

        common_kwargs = {
            "model": model_name,
            "device": DEVICE,  # move to GPU if available
            "trust_remote_code": True,  # some repos ship custom modeling code
        }
        if torch.cuda.is_available():
            # bfloat16 is only a win on GPU; forcing it on CPU slows or
            # breaks inference for some ops.
            common_kwargs["torch_dtype"] = torch.bfloat16
            try:
                # The old `use_flash_attention_2` kwarg is deprecated; the
                # current spelling is `attn_implementation`. It still raises
                # at load time when the flash-attn package is missing or the
                # GPU arch is unsupported, so fall back to default attention.
                return (
                    pipeline(
                        "conversational",
                        attn_implementation="flash_attention_2",
                        **common_kwargs,
                    ),
                    "Model loaded successfully!",
                )
            except Exception:
                pass  # retry below without FlashAttention-2

        pipe = pipeline("conversational", **common_kwargs)
        return pipe, "Model loaded successfully!"
    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return None, f"Error loading model: {e}"


def generate_response(prompt, chat_history, model_name, hf_token=None):
    """Generate a chat reply and append it to the Gradio history.

    Parameters
    ----------
    prompt : str
        The user's new message.
    chat_history : list[tuple[str, str]] | None
        Gradio Chatbot history as ``(user, bot)`` pairs.
    model_name : str
        Model id to use; loaded lazily and cached across calls.
    hf_token : str | None
        Optional Hugging Face token, forwarded to :func:`load_model`.

    Returns
    -------
    tuple
        ``(textbox_value, chat_history)`` — the textbox is cleared on
        success; on failure it carries the error message instead.
    """
    chat_history = chat_history or []  # Gradio may hand us None on first call

    # Ignore empty submissions instead of running a pointless generation.
    if not prompt or not prompt.strip():
        return "", chat_history

    # Function-attribute dict caches loaded pipelines so switching models in
    # the dropdown doesn't re-download/re-load on every message.
    if not hasattr(generate_response, "loaded_models"):
        generate_response.loaded_models = {}
    cache = generate_response.loaded_models

    if model_name not in cache:
        pipe, load_status = load_model(model_name, hf_token)
        if pipe is None:
            # Show the load error in the textbox; keep history untouched.
            return load_status, chat_history
        cache[model_name] = pipe
        print(f"Model {model_name} loaded.")  # Debugging message
    else:
        print(f"Using cached model {model_name}.")  # Debugging message

    pipe = cache[model_name]

    try:
        conversation = _build_conversation(prompt, chat_history)

        # Generation parameters (adjust these!)
        generation_kwargs = {
            "max_new_tokens": 512,
            "do_sample": True,
            "top_p": 0.95,
            "temperature": 0.7,
            "repetition_penalty": 1.1
        }

        response = pipe(conversation, **generation_kwargs)

        # The pipeline appends the assistant's turn to the Conversation;
        # the last message is therefore the new reply.
        bot_response = response.messages[-1]["content"]

        chat_history.append((prompt, bot_response))
        return "", chat_history

    except Exception as e:
        # Report in the textbox so the UI stays responsive.
        return f"Error during generation: {e}", chat_history


def _build_conversation(prompt, chat_history):
    """Convert Gradio ``(user, bot)`` history + new prompt into a Conversation."""
    conversation = Conversation()
    for user_message, bot_message in chat_history:
        conversation.add_message({"role": "user", "content": user_message})
        if bot_message:  # bot may not have responded yet
            conversation.add_message({"role": "assistant", "content": bot_message})
    conversation.add_message({"role": "user", "content": prompt})
    return conversation


# --- Gradio Interface ---

# Build the UI: model picker + optional token field + chat widget.
with gr.Blocks(title="Chat with a Powerful AI") as iface:
    gr.Markdown(
        """
        # Chat with Different AI Models
        This Space demonstrates a chatbot that allows you to select from different AI models.
        Choose a model from the dropdown and start chatting!
        """
    )

    # Dropdown of supported model ids; defaults to the first entry.
    model_selection = gr.Dropdown(
        choices=MODEL_CHOICES,
        value=MODEL_CHOICES[0],  # Default model
        label="Select Model",
        info="Choose the AI model you want to chat with."
    )

    # Masked input for gated models (e.g. Llama 2); forwarded to load_model.
    hf_token_input = gr.Textbox(
        label="Hugging Face Token (Optional, for gated models)",
        type="password",
        placeholder="Enter your Hugging Face token (if required)",
    )

    chatbot = gr.Chatbot(label="Chat History", height=500)  # Set a reasonable height
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.ClearButton([msg, chatbot])  # resets both the textbox and history

    # Enter in the textbox triggers generation; outputs clear the textbox
    # (or show an error there) and refresh the chat history.
    msg.submit(
        generate_response,
        [msg, chatbot, model_selection, hf_token_input],
        [msg, chatbot],
    )

# Start the Gradio server (blocking call).
iface.launch()