# NOTE(review): the following lines are Hugging Face web-page chrome that was
# pasted into the file; commented out so the module parses.
# vedaco's picture
# Update app.py
# 108c40e verified
# raw
# history blame
# 9.4 kB
"""Gradio App - REPLACED with chat interface"""
import gradio as gr
import tensorflow as tf
import os
import json
from model import VedaProgrammingLLM
from tokenizer import VedaTokenizer
from database import db
from train import VedaTrainer
from config import MODEL_DIR
# Global state shared across the Gradio callbacks below.
model = None                # VedaProgrammingLLM instance; set by initialize()/retrain()
tokenizer = None            # VedaTokenizer instance; set alongside `model`
conversation_history = []   # list of {'user': str, 'assistant': str} dicts (in-memory context)
current_conv_id = -1        # DB id of the most recent saved conversation (-1 = none yet)
def initialize():
    """Load a previously saved model from MODEL_DIR, or train a fresh one.

    Populates the module-level ``model`` and ``tokenizer`` globals.
    A saved model is detected by the presence of ``config.json``.
    """
    global model, tokenizer
    print("πŸ•‰οΈ Initializing Veda Programming Assistant...")
    config_path = os.path.join(MODEL_DIR, "config.json")

    if not os.path.exists(config_path):
        # Nothing saved on disk -> train from scratch.
        print("Training new model (this takes a few minutes)...")
        trainer = VedaTrainer()
        trainer.train(epochs=15)
        model = trainer.model
        tokenizer = trainer.tokenizer
        print("βœ… Model trained!")
        return

    print("Loading existing model...")
    with open(config_path, 'r') as f:
        cfg = json.load(f)
    tokenizer = VedaTokenizer()
    tokenizer.load(os.path.join(MODEL_DIR, "tokenizer.json"))
    model = VedaProgrammingLLM(
        vocab_size=cfg['vocab_size'],
        max_length=cfg['max_length'],
        d_model=cfg['d_model'],
        num_heads=cfg['num_heads'],
        num_layers=cfg['num_layers'],
        ff_dim=cfg['ff_dim']
    )
    # Dummy forward pass builds the variables so load_weights() can match them.
    model(tf.zeros((1, cfg['max_length']), dtype=tf.int32))
    model.load_weights(os.path.join(MODEL_DIR, "weights.h5"))
    print("βœ… Model loaded!")
def clean_response(text: str) -> str:
    """Convert code markers to Markdown fences, strip special tokens,
    and cap consecutive blank lines at two.
    """
    # Code-block markers become fenced Markdown.
    text = text.replace("<CODE>", "\n```python\n").replace("<ENDCODE>", "\n```\n")

    # Drop every remaining special token.
    for tok in ("<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"):
        text = text.replace(tok, "")

    # Keep at most two blank lines in a row.
    kept = []
    blanks = 0
    for raw in text.split('\n'):
        if raw.strip():
            blanks = 0
            kept.append(raw)
        else:
            blanks += 1
            if blanks <= 2:
                kept.append(raw)
    return '\n'.join(kept).strip()
def generate_response(user_input: str, temperature: float = 0.7,
                      max_tokens: int = 200) -> str:
    """Generate one assistant reply for ``user_input``.

    Builds a prompt from the last three exchanges, samples from the global
    model, cleans the output, then records the exchange in memory and in the
    database. Returns the reply text, or a status/error string.
    """
    global current_conv_id
    if model is None:
        return "⏳ Model is loading..."
    if not user_input.strip():
        return "Please type a message!"
    try:
        # Context window: the three most recent exchanges.
        turns = [
            f"<USER> {t['user']}\n<ASSISTANT> {t['assistant']}\n"
            for t in conversation_history[-3:]
        ]
        prompt = "".join(turns) + f"<USER> {user_input}\n<ASSISTANT>"

        tokens = tokenizer.encode(prompt)
        # Leave room for the reply inside the model's context length.
        budget = model.max_length - max_tokens
        if len(tokens) > budget:
            tokens = tokens[-budget:]

        generated = model.generate(
            tokens,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2
        )
        response = tokenizer.decode(generated)

        # Keep only the text after the final assistant marker, and cut off
        # anything the model hallucinated as a follow-up user turn.
        if "<ASSISTANT>" in response:
            response = response.split("<ASSISTANT>")[-1].strip()
        if "<USER>" in response:
            response = response.split("<USER>")[0].strip()
        response = clean_response(response)

        # Persist the exchange in memory and in the feedback database.
        conversation_history.append({'user': user_input, 'assistant': response})
        current_conv_id = db.save_conversation(user_input, response)
        return response
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error: {str(e)}"
def chat(user_input, history, temperature, max_tokens):
    """Gradio event handler: append one (user, reply) pair to the chat.

    Returns ``("", history)`` so the input box is cleared after sending.
    """
    reply = generate_response(user_input, temperature, max_tokens)
    history.append((user_input, reply))
    return "", history
def feedback_good():
    """Record a thumbs-up for the most recently saved conversation."""
    if current_conv_id <= 0:
        # Nothing saved yet -> nothing to rate.
        return ""
    db.update_feedback(current_conv_id, 1)
    return "πŸ‘ Thanks! This helps me improve."
def feedback_bad():
    """Record a thumbs-down for the most recently saved conversation."""
    if current_conv_id <= 0:
        # Nothing saved yet -> nothing to rate.
        return ""
    db.update_feedback(current_conv_id, -1)
    return "πŸ‘Ž Thanks for the feedback. I'll try to do better."
def clear_conversation():
    """Drop the in-memory history; reset the chat widget and feedback box."""
    global conversation_history
    conversation_history = []
    # (empty chatbot list, empty feedback message)
    return [], ""
def retrain(epochs):
    """Fine-tune on all positively-rated conversations from the database.

    Replaces the global ``model``/``tokenizer`` with the retrained pair and
    returns a status string with the final training loss.
    """
    global model, tokenizer
    good_convs = db.get_good_conversations()
    if not good_convs:
        return "No approved conversations yet. Rate some responses first!"

    # Format the approved exchanges in the same <USER>/<ASSISTANT> scheme
    # used at generation time.
    chunks = []
    for conv in good_convs:
        chunks.append(f"<USER> {conv['user_input']}\n")
        chunks.append(f"<ASSISTANT> {conv['assistant_response']}\n\n")
    extra_data = "".join(chunks)

    trainer = VedaTrainer()
    history = trainer.train(epochs=int(epochs), extra_data=extra_data)
    model = trainer.model
    tokenizer = trainer.tokenizer

    final_loss = history.history['loss'][-1]
    return f"βœ… Training done! Loss: {final_loss:.4f}, Used {len(good_convs)} conversations"
def get_stats():
    """Render conversation/feedback counts as a Markdown table."""
    counts = db.get_stats()
    template = (
        "## πŸ“Š Statistics\n"
        "| Metric | Count |\n"
        "|--------|-------|\n"
        "| πŸ’¬ Conversations | {total} |\n"
        "| πŸ‘ Positive | {positive} |\n"
        "| πŸ‘Ž Negative | {negative} |\n"
    )
    return template.format(**counts)
# Create interface
def create_app():
    """Build the Gradio Blocks UI: chat, training, and stats tabs.

    Returns the (unlaunched) ``gr.Blocks`` app. Component creation order
    determines the rendered layout, so statements here must not be reordered.
    """
    with gr.Blocks(title="Veda Programming Assistant", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # πŸ•‰οΈ Veda Programming Assistant
        I can **chat**, **write code**, **explain concepts**, and **answer questions**!
        """)
        with gr.Tabs():
            # Chat Tab: conversation view, input row, sampling controls, feedback
            with gr.TabItem("πŸ’¬ Chat"):
                chatbot = gr.Chatbot(label="Conversation", height=400)
                with gr.Row():
                    msg = gr.Textbox(
                        label="Your message",
                        placeholder="Ask me anything about programming...",
                        lines=2,
                        scale=4
                    )
                    send_btn = gr.Button("Send πŸ“€", variant="primary", scale=1)
                with gr.Row():
                    # Sliders feed straight into generate_response() via chat().
                    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Creativity")
                    max_tokens = gr.Slider(50, 400, value=200, step=50, label="Response length")
                with gr.Row():
                    good_btn = gr.Button("πŸ‘ Good", variant="secondary")
                    bad_btn = gr.Button("πŸ‘Ž Bad", variant="secondary")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
                feedback_msg = gr.Textbox(label="", lines=1)
                # Events: button click and Enter-submit share the chat() handler.
                send_btn.click(chat, [msg, chatbot, temperature, max_tokens], [msg, chatbot])
                msg.submit(chat, [msg, chatbot, temperature, max_tokens], [msg, chatbot])
                good_btn.click(feedback_good, outputs=feedback_msg)
                bad_btn.click(feedback_bad, outputs=feedback_msg)
                clear_btn.click(clear_conversation, outputs=[chatbot, feedback_msg])
                gr.Markdown("### πŸ’‘ Try these:")
                # Clicking an example fills the message textbox.
                gr.Examples(
                    examples=[
                        ["Hello! What can you do?"],
                        ["What is Python?"],
                        ["Write a function to calculate factorial"],
                        ["Explain what recursion is"],
                        ["How do I read a file in Python?"],
                        ["Write a bubble sort algorithm"],
                        ["What's the difference between list and tuple?"],
                    ],
                    inputs=msg
                )
            # Training Tab: retrain on thumbs-up conversations
            with gr.TabItem("πŸŽ“ Training"):
                gr.Markdown("### Train on your approved conversations")
                train_epochs = gr.Slider(5, 20, value=10, step=1, label="Epochs")
                train_btn = gr.Button("πŸ”„ Retrain", variant="primary")
                train_output = gr.Markdown()
                train_btn.click(retrain, [train_epochs], train_output)
            # Stats Tab: on-demand refresh of database counters
            with gr.TabItem("πŸ“Š Stats"):
                stats_out = gr.Markdown()
                refresh_btn = gr.Button("πŸ”„ Refresh")
                refresh_btn.click(get_stats, outputs=stats_out)
        gr.Markdown("---\n**Veda Programming Assistant** | Learning from every conversation!")
    return app
# Main entry point: load/train the model, then serve the UI.
if __name__ == "__main__":
    initialize()
    print("\nπŸš€ Starting...")
    app = create_app()
    # Bind to all interfaces on port 7860 (the conventional HF Spaces port).
    app.launch(server_name="0.0.0.0", server_port=7860)