# NOTE(review): the three lines below were Hugging Face Spaces page residue
# ("Spaces: / Sleeping / Sleeping") accidentally captured with the source;
# commented out so the module parses.
| """Gradio App - REPLACED with chat interface""" | |
| import gradio as gr | |
| import tensorflow as tf | |
| import os | |
| import json | |
| from model import VedaProgrammingLLM | |
| from tokenizer import VedaTokenizer | |
| from database import db | |
| from train import VedaTrainer | |
| from config import MODEL_DIR | |
# Global state
model = None  # VedaProgrammingLLM instance; set by initialize() or retrain()
tokenizer = None  # VedaTokenizer paired with `model`; set alongside it
conversation_history = []  # list of {'user': str, 'assistant': str} exchanges used as context
current_conv_id = -1  # DB row id of the most recent saved exchange; -1 = nothing saved yet
def initialize():
    """Load a saved model from MODEL_DIR, or train a fresh one if no checkpoint exists.

    Populates the module-level ``model`` and ``tokenizer`` globals.
    """
    global model, tokenizer
    print("ποΈ Initializing Veda Programming Assistant...")
    config_path = os.path.join(MODEL_DIR, "config.json")
    if not os.path.exists(config_path):
        # No saved checkpoint: train from scratch and adopt the trainer's artifacts.
        print("Training new model (this takes a few minutes)...")
        trainer = VedaTrainer()
        trainer.train(epochs=15)
        model, tokenizer = trainer.model, trainer.tokenizer
        print("β Model trained!")
        return
    print("Loading existing model...")
    with open(config_path, 'r') as f:
        cfg = json.load(f)
    tokenizer = VedaTokenizer()
    tokenizer.load(os.path.join(MODEL_DIR, "tokenizer.json"))
    model = VedaProgrammingLLM(
        vocab_size=cfg['vocab_size'],
        max_length=cfg['max_length'],
        d_model=cfg['d_model'],
        num_heads=cfg['num_heads'],
        num_layers=cfg['num_layers'],
        ff_dim=cfg['ff_dim'],
    )
    # Dummy forward pass builds the layer weights before load_weights is called.
    model(tf.zeros((1, cfg['max_length']), dtype=tf.int32))
    model.load_weights(os.path.join(MODEL_DIR, "weights.h5"))
    print("β Model loaded!")
def clean_response(text: str) -> str:
    """Normalize raw model output into display-ready markdown.

    Converts the tokenizer's code sentinels into markdown fences, strips all
    special tokens, and collapses runs of blank lines to at most two.
    """
    # Translate code-fence sentinels into markdown fences.
    text = text.replace("<CODE>", "\n```python\n").replace("<ENDCODE>", "\n```\n")
    # Drop every control token the model may have emitted verbatim.
    for special in ("<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"):
        text = text.replace(special, "")
    # Keep at most two consecutive blank lines.
    kept = []
    blank_run = 0
    for raw_line in text.split('\n'):
        if raw_line.strip():
            blank_run = 0
            kept.append(raw_line)
        else:
            blank_run += 1
            if blank_run <= 2:
                kept.append(raw_line)
    return '\n'.join(kept).strip()
def generate_response(user_input: str, temperature: float = 0.7,
                      max_tokens: int = 200) -> str:
    """Generate a reply to *user_input* using recent conversation context.

    Builds a prompt from the last 3 exchanges, samples from the model,
    extracts/cleans the assistant turn, records it in the in-memory history,
    and persists the exchange to the database (setting ``current_conv_id``).
    Returns the cleaned reply, or a user-facing status/error string.
    """
    global current_conv_id
    if model is None:
        return "β³ Model is loading..."
    if not user_input.strip():
        return "Please type a message!"
    try:
        # Prompt = last 3 exchanges of context + the new user turn.
        context_parts = [
            f"<USER> {msg['user']}\n<ASSISTANT> {msg['assistant']}\n"
            for msg in conversation_history[-3:]
        ]
        prompt = "".join(context_parts) + f"<USER> {user_input}\n<ASSISTANT>"
        tokens = tokenizer.encode(prompt)
        # Keep only the tail of the prompt so prompt + generation fits the window.
        budget = model.max_length - max_tokens
        if len(tokens) > budget:
            tokens = tokens[-budget:]
        generated = model.generate(
            tokens,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
        )
        response = tokenizer.decode(generated)
        # Keep only text after the final <ASSISTANT> marker, and drop anything
        # the model hallucinated for a following user turn.
        if "<ASSISTANT>" in response:
            response = response.split("<ASSISTANT>")[-1].strip()
        if "<USER>" in response:
            response = response.split("<USER>")[0].strip()
        response = clean_response(response)
        conversation_history.append({'user': user_input, 'assistant': response})
        # Persist so the exchange can be rated and later used for retraining.
        current_conv_id = db.save_conversation(user_input, response)
        return response
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"β Error: {str(e)}"
def chat(user_input, history, temperature, max_tokens):
    """Gradio event handler: generate a reply and push the exchange onto history."""
    reply = generate_response(user_input, temperature, max_tokens)
    history.append((user_input, reply))
    # First output clears the input textbox; second refreshes the chatbot.
    return "", history
def feedback_good():
    """Record a thumbs-up for the most recent exchange, if one was saved."""
    if current_conv_id <= 0:
        # Nothing has been saved yet — silently no-op.
        return ""
    db.update_feedback(current_conv_id, 1)
    return "π Thanks! This helps me improve."
def feedback_bad():
    """Record a thumbs-down for the most recent exchange, if one was saved."""
    if current_conv_id <= 0:
        # Nothing has been saved yet — silently no-op.
        return ""
    db.update_feedback(current_conv_id, -1)
    return "π Thanks for the feedback. I'll try to do better."
def clear_conversation():
    """Drop the in-memory context window and blank the chat UI."""
    global conversation_history
    conversation_history = []
    # Outputs: empty chatbot history, empty feedback message.
    return ([], "")
def retrain(epochs):
    """Fine-tune the model on conversations the user rated positively.

    Returns a markdown status string for the Training tab.
    """
    global model, tokenizer
    good_convs = db.get_good_conversations()
    if not good_convs:
        return "No approved conversations yet. Rate some responses first!"
    # Format approved exchanges in the same <USER>/<ASSISTANT> layout the
    # trainer expects for its base corpus.
    chunks = []
    for conv in good_convs:
        chunks.append(f"<USER> {conv['user_input']}\n")
        chunks.append(f"<ASSISTANT> {conv['assistant_response']}\n\n")
    trainer = VedaTrainer()
    history = trainer.train(epochs=int(epochs), extra_data="".join(chunks))
    # Swap the freshly trained artifacts in for subsequent chats.
    model, tokenizer = trainer.model, trainer.tokenizer
    loss = history.history['loss'][-1]
    return f"β Training done! Loss: {loss:.4f}, Used {len(good_convs)} conversations"
def get_stats():
    """Render the stored feedback counters as a markdown table."""
    stats = db.get_stats()
    total, pos, neg = stats['total'], stats['positive'], stats['negative']
    return f"""## π Statistics
| Metric | Count |
|--------|-------|
| π¬ Conversations | {total} |
| π Positive | {pos} |
| π Negative | {neg} |
"""
# Create interface
def create_app():
    """Build the three-tab Gradio UI: chat, retraining, and statistics."""
    with gr.Blocks(title="Veda Programming Assistant", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# ποΈ Veda Programming Assistant
I can **chat**, **write code**, **explain concepts**, and **answer questions**!
""")
        with gr.Tabs():
            # --- Chat tab: conversation window + generation controls ---
            with gr.TabItem("π¬ Chat"):
                chat_window = gr.Chatbot(label="Conversation", height=400)
                with gr.Row():
                    user_box = gr.Textbox(
                        label="Your message",
                        placeholder="Ask me anything about programming...",
                        lines=2,
                        scale=4,
                    )
                    send_button = gr.Button("Send π€", variant="primary", scale=1)
                with gr.Row():
                    temp_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Creativity")
                    length_slider = gr.Slider(50, 400, value=200, step=50, label="Response length")
                with gr.Row():
                    thumbs_up = gr.Button("π Good", variant="secondary")
                    thumbs_down = gr.Button("π Bad", variant="secondary")
                    reset_button = gr.Button("ποΈ Clear", variant="secondary")
                feedback_box = gr.Textbox(label="", lines=1)

                # Wire both the Send button and Enter-in-textbox to the same handler.
                chat_inputs = [user_box, chat_window, temp_slider, length_slider]
                chat_outputs = [user_box, chat_window]
                send_button.click(chat, chat_inputs, chat_outputs)
                user_box.submit(chat, chat_inputs, chat_outputs)
                thumbs_up.click(feedback_good, outputs=feedback_box)
                thumbs_down.click(feedback_bad, outputs=feedback_box)
                reset_button.click(clear_conversation, outputs=[chat_window, feedback_box])

                gr.Markdown("### π‘ Try these:")
                gr.Examples(
                    examples=[
                        ["Hello! What can you do?"],
                        ["What is Python?"],
                        ["Write a function to calculate factorial"],
                        ["Explain what recursion is"],
                        ["How do I read a file in Python?"],
                        ["Write a bubble sort algorithm"],
                        ["What's the difference between list and tuple?"],
                    ],
                    inputs=user_box,
                )
            # --- Training tab: retrain on thumbs-up conversations ---
            with gr.TabItem("π Training"):
                gr.Markdown("### Train on your approved conversations")
                epoch_slider = gr.Slider(5, 20, value=10, step=1, label="Epochs")
                retrain_button = gr.Button("π Retrain", variant="primary")
                retrain_output = gr.Markdown()
                retrain_button.click(retrain, [epoch_slider], retrain_output)
            # --- Stats tab: feedback counters, refreshed on demand ---
            with gr.TabItem("π Stats"):
                stats_panel = gr.Markdown()
                refresh_button = gr.Button("π Refresh")
                refresh_button.click(get_stats, outputs=stats_panel)
        gr.Markdown("---\n**Veda Programming Assistant** | Learning from every conversation!")
    return demo
# Main
if __name__ == "__main__":
    # Load (or train) the model before the UI starts serving requests.
    initialize()
    print("\nπ Starting...")
    create_app().launch(server_name="0.0.0.0", server_port=7860)