import gradio as gr
from transformers import pipeline
import torch

class LLMModule:
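    """Gradio module for loading small chat LLMs and testing them interactively"""
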
    def __init__(self):
        self.model_options = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Phi-2": "microsoft/phi-2",
            "Qwen 0.5B": "Qwen/Qwen2.5-0.5B-Instruct"
        }
        self.current_model = None
        self.pipe = None
        self.chat_history = []

    def load_model(self, model_name):
        """Load LLM model"""
        try:
            model_id = self.model_options[model_name]
            device = "cuda" if torch.cuda.is_available() else "cpu"

            self.pipe = pipeline(
                "text-generation",
                model=model_id,
                device=device,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            )
            self.current_model = model_name
            self.chat_history = []
            return f"✓ Loaded {model_name} on {device}"
        except Exception as e:
            return f"✗ Error loading model: {str(e)}"

    def generate_response(self, message, max_tokens, temperature):
        """Generate LLM response"""
        if self.pipe is None:
            return "⚠ Please load a model first", []

        if not message.strip():
            return "⚠ Please enter a message", self.chat_history

        try:
            # Add user message to history
            self.chat_history.append({"role": "user", "content": message})

            # Generate a reply from the latest message only; earlier turns in
            # chat_history are kept for display but are not sent to the model
            response = self.pipe(
                message,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
                top_p=0.9
            )

            assistant_message = response[0]["generated_text"]

            # Clean up if the model repeats the input
            if assistant_message.startswith(message):
                assistant_message = assistant_message[len(message):].strip()

            # Add assistant response to history
            self.chat_history.append({"role": "assistant", "content": assistant_message})

            # History is already in {"role": ..., "content": ...} format,
            # which the type="messages" chatbot renders directly
            return "", self.chat_history
        except Exception as e:
            return f"✗ Error generating response: {str(e)}", self.chat_history

    def clear_history(self):
        """Clear chat history"""
        self.chat_history = []
        return [], ""

    def create_interface(self):
        """Create Gradio interface for LLM testing"""
        with gr.Column() as interface:
            gr.Markdown("## 🤖 LLM Testing")

            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(self.model_options.keys()),
                    value="Qwen 0.5B",
                    label="Select LLM Model"
                )
                load_btn = gr.Button("Load Model", variant="primary")

            status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Chat Interface")
            chatbot = gr.Chatbot(label="Conversation", height=400, type="messages")

            with gr.Row():
                message_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message...",
                    scale=4
                )
                send_btn = gr.Button("Send", variant="secondary", scale=1)

            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=500,
                    value=150,
                    step=10,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )

            clear_btn = gr.Button("Clear Chat", variant="stop")

            load_btn.click(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[status]
            )

            send_btn.click(
                fn=self.generate_response,
                inputs=[message_input, max_tokens, temperature],
                outputs=[message_input, chatbot]
            )

            message_input.submit(
                fn=self.generate_response,
                inputs=[message_input, max_tokens, temperature],
                outputs=[message_input, chatbot]
            )

            clear_btn.click(
                fn=self.clear_history,
                outputs=[chatbot, message_input]
            )

        return interface
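

# Minimal standalone entry point (a sketch, not part of the original module):
# wraps the component column in a gr.Blocks app so the module can be launched
# and tested on its own. The title string is an assumption.
if __name__ == "__main__":
    llm_module = LLMModule()
    with gr.Blocks(title="LLM Testing") as demo:
        llm_module.create_interface()
    demo.launch()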