# Torstens_agent / llm.py
# Uploaded by eduard76 ("Upload 6 files", commit 4480d43 verified)
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
class LLMModule:
    """Gradio chat UI wrapped around a Hugging Face text-generation pipeline.

    Holds a small catalog of chat-capable models, loads one on demand via
    ``transformers.pipeline``, and exposes a simple send/clear chat interface.
    """

    def __init__(self):
        # Friendly label -> Hugging Face Hub model id.
        self.model_options = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Phi-2": "microsoft/phi-2",
            "Qwen 0.5B": "Qwen/Qwen2.5-0.5B-Instruct"
        }
        self.current_model = None   # label of the currently loaded model, if any
        self.pipe = None            # transformers pipeline set by load_model()
        self.chat_history = []      # alternating {"role", "content"} dicts

    def load_model(self, model_name):
        """Load the selected model into a text-generation pipeline.

        Args:
            model_name: one of the keys of ``self.model_options``.

        Returns:
            A human-readable status string ("✓ ..." on success, "✗ ..." on error).
        """
        try:
            model_id = self.model_options[model_name]
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.pipe = pipeline(
                "text-generation",
                model=model_id,
                device=device,
                # fp16 halves GPU memory; CPU kernels generally need fp32.
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            )
            self.current_model = model_name
            # A new model invalidates the previous conversation.
            self.chat_history = []
            return f"✓ Loaded {model_name} on {device}"
        except Exception as e:
            return f"✗ Error loading model: {str(e)}"

    def _format_chat(self):
        """Convert chat_history into (user, assistant) tuples for gr.Chatbot.

        Pairs entry 2i (user) with entry 2i+1 (assistant); a trailing unpaired
        user message (e.g. after a generation failure) is skipped.
        """
        return [
            (self.chat_history[i]["content"], self.chat_history[i + 1]["content"])
            for i in range(0, len(self.chat_history) - 1, 2)
        ]

    def generate_response(self, message, max_tokens, temperature):
        """Generate one assistant turn and return the updated chat display.

        Args:
            message: user input text.
            max_tokens: max new tokens to sample (coerced to int).
            temperature: sampling temperature (coerced to float).

        Returns:
            ``(textbox_value, chat_pairs)`` — textbox is cleared ("") on
            success or carries a warning/error string otherwise.
        """
        if self.pipe is None:
            return "⚠ Please load a model first", []
        if not message.strip():
            # Return tuples, not raw role dicts, so gr.Chatbot can render them.
            return "⚠ Please enter a message", self._format_chat()
        try:
            self.chat_history.append({"role": "user", "content": message})
            # NOTE(review): only the latest message is sent to the model;
            # earlier turns in chat_history are NOT included in the prompt,
            # so the model has no conversational context — confirm intent.
            response = self.pipe(
                message,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
                top_p=0.9
            )
            assistant_message = response[0]["generated_text"]
            # text-generation pipelines echo the prompt; strip it if present.
            if assistant_message.startswith(message):
                assistant_message = assistant_message[len(message):].strip()
            self.chat_history.append({"role": "assistant", "content": assistant_message})
            # BUG FIX: the original paired history[i+1] while enumerating
            # history[::2], matching the 2nd user message with itself instead
            # of its reply. _format_chat uses the correct 2i/2i+1 indices.
            return "", self._format_chat()
        except Exception as e:
            return f"✗ Error generating response: {str(e)}", self._format_chat()

    def clear_history(self):
        """Clear chat history; returns (empty chatbot, empty textbox)."""
        self.chat_history = []
        return [], ""

    def create_interface(self):
        """Build and return the Gradio column with model picker + chat widgets."""
        with gr.Column() as interface:
            gr.Markdown("## 🤖 LLM Testing")
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(self.model_options.keys()),
                    value="Qwen 0.5B",
                    label="Select LLM Model"
                )
                load_btn = gr.Button("Load Model", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Chat Interface")
            chatbot = gr.Chatbot(label="Conversation", height=400)
            with gr.Row():
                message_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message...",
                    scale=4
                )
                send_btn = gr.Button("Send", variant="secondary", scale=1)
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=500,
                    value=150,
                    step=10,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            clear_btn = gr.Button("Clear Chat", variant="stop")
            # Wire events: both Send click and Enter submit share one handler.
            load_btn.click(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[status]
            )
            send_btn.click(
                fn=self.generate_response,
                inputs=[message_input, max_tokens, temperature],
                outputs=[message_input, chatbot]
            )
            message_input.submit(
                fn=self.generate_response,
                inputs=[message_input, max_tokens, temperature],
                outputs=[message_input, chatbot]
            )
            clear_btn.click(
                fn=self.clear_history,
                outputs=[chatbot, message_input]
            )
        return interface