| import os |
| import time |
| import gradio as gr |
| from unsloth import FastLanguageModel |
| from transformers import TrainingArguments, TextStreamer |
| from datasets import load_dataset |
|
|
| |
# Cap OpenMP at one thread to avoid CPU oversubscription alongside the
# model libraries' own worker threads.
os.environ["OMP_NUM_THREADS"] = "1"

# Checkpoint to load; unsloth returns both the model and its tokenizer.
model_name = "EpistemeAI/Episteme-gptoss-20b-RL"
print(f"Loading model: {model_name}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    load_in_4bit=False,  # full-precision weights, no 4-bit quantization
    dtype=None,          # let unsloth choose the dtype automatically
    device_map="auto",   # place layers across available devices
)

# Mutable module-level state shared by the Gradio callbacks below.
conversation_history = []    # list of {"role", "content"} chat messages
training_data = []           # buffered {"instruction", "expected_output"} pairs
TRAIN_EVERY_N_PROMPTS = 8    # retrain once this many pairs have accumulated
training_status = "Idle"     # surfaced in the UI via get_training_status()
|
|
| |
def chat_response(user_input):
    """Generator callback for the Send button.

    Appends the user turn to the shared ``conversation_history``,
    generates an assistant reply, buffers the pair in ``training_data``,
    and yields ``[chat messages, status text]`` updates for the UI.
    Triggers ``train_model`` once TRAIN_EVERY_N_PROMPTS pairs accumulate.
    """
    global training_data, training_status

    conversation_history.append({"role": "user", "content": user_input})

    # Flatten the whole history into a plain "role: content" transcript.
    full_prompt = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in conversation_history
    )
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # TextStreamer only echoes tokens to stdout — it exposes no `.text`
    # attribute (the original `streamer.text` raised AttributeError).
    # Recover the reply by decoding the generated ids past the prompt.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    output_ids = model.generate(**inputs, streamer=streamer, max_new_tokens=200)
    prompt_len = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    conversation_history.append({"role": "assistant", "content": response_text})
    training_data.append({"instruction": user_input, "expected_output": response_text})

    # Auto-retrain once enough new examples have been collected.
    if len(training_data) >= TRAIN_EVERY_N_PROMPTS:
        training_status = "Training..."
        yield [
            conversation_history,
            "Auto-training in progress... this may take a moment."
        ]
        train_model(training_data)
        training_data.clear()
        training_status = "Idle"

    yield [conversation_history, f"Assistant: {response_text}"]
|
|
| |
def train_model(data_samples):
    """Fine-tune the global model on buffered chat pairs.

    data_samples: list of {"instruction", "expected_output"} dicts
    collected by chat_response().
    """
    # NOTE(review): this is a plain dict of column lists; HF-style trainers
    # normally require a datasets.Dataset (e.g. Dataset.from_dict(...)) —
    # the `load_dataset` import above is unused. Verify what the training
    # call actually accepts.
    dataset = {
        "instruction": [d["instruction"] for d in data_samples],
        "expected_output": [d["expected_output"] for d in data_samples],
    }
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=2e-4,
        logging_dir="./logs",
        save_strategy="no",        # no intermediate checkpoints
        optim="adamw_bnb_8bit",    # 8-bit optimizer to reduce memory
        fp16=True,
        lr_scheduler_type="cosine",
    )

    # NOTE(review): unsloth's FastLanguageModel does not document a
    # `train_model` method — the usual path is FastLanguageModel.get_peft_model
    # plus trl's SFTTrainer. Confirm this call exists before relying on it;
    # as written it will likely raise AttributeError at the first retrain.
    model.train_model(
        dataset=dataset,
        training_args=training_args,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.1,
    )
    print("✅ Model retrained successfully")
|
|
| |
def get_training_status():
    """Report the current auto-training state for the UI status box."""
    return "Training status: {}".format(training_status)
|
|
| |
def clear_chat():
    """Wipe the stored conversation and report the reset to the UI.

    Rebinds the module-level history to a fresh list (same effect as the
    original) and returns the cleared transcript plus a status message.
    """
    global conversation_history
    conversation_history = list()
    return list(), "Chat cleared."
|
|
| |
# ---- Gradio UI wiring ----
with gr.Blocks(title="🧠 Auto-Training Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 🤖 Adaptive AI Chatbot (with Auto-Retraining)")

    with gr.Row():
        # Left: chat transcript fed role/content dicts (type="messages"
        # matches the shape of conversation_history entries).
        chatbot = gr.Chatbot(
            height=500,
            label="Chat Interface",
            type="messages"
        )
        with gr.Column():
            # Right: read-only training status plus a reset button.
            status_display = gr.Textbox(
                label="Training Status",
                value="Idle",
                interactive=False
            )
            clear_btn = gr.Button("Clear Chat")

    user_input = gr.Textbox(
        label="Your message",
        placeholder="Type your question or message here..."
    )
    send_btn = gr.Button("Send")

    # chat_response is a generator, so the chat window and status box
    # update on each yield.
    send_btn.click(
        fn=chat_response,
        inputs=[user_input],
        outputs=[chatbot, status_display],
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, status_display],
    )

    # Poll every 2 s so the status box reflects retraining started
    # inside chat_response.
    timer = gr.Timer(interval=2.0, active=True)
    timer.tick(fn=get_training_status, outputs=[status_display])
|
|
| |
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()