"""Gradio App for Veda Programming Assistant - Gradio 6.x compatible (with math solving)"""
import gradio as gr
import tensorflow as tf
import os
import json
import re
import ast
import operator as op
from model import VedaProgrammingLLM
from tokenizer import VedaTokenizer
from database import db
from train import VedaTrainer
from config import MODEL_DIR
# --------- Globals ----------
model = None
tokenizer = None
conversation_history = [] # used for building prompt context for the model
current_conv_id = -1
# --------- Helpers (Gradio message parsing) ----------
def extract_text(message):
    """
    Convert Gradio multimodal / messages objects -> plain string.
    Handles:
    - str
    - dict: {"text": "..."} or {"content": "..."}
    - list of parts: [{"type":"text","text":"..."}]
    """
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    if isinstance(message, dict):
        if "text" in message:
            return str(message.get("text", ""))
        if "content" in message:
            return extract_text(message["content"])
        return ""
    if isinstance(message, list):
        parts = []
        for part in message:
            if isinstance(part, dict) and part.get("type") == "text":
                parts.append(str(part.get("text", "")))
            elif isinstance(part, str):
                parts.append(part)
        return "".join(parts).strip()
    return str(message)
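
# Illustrative sanity checks for extract_text (payload shapes assumed from the
# Gradio message formats above, not exhaustive):
#   extract_text("hi")                                  -> "hi"
#   extract_text({"text": "hi", "files": []})           -> "hi"
#   extract_text([{"type": "text", "text": "a"}, "b"])  -> "ab"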
def ensure_messages_history(history):
    """
    Ensure the Chatbot history is ALWAYS in messages format:
    [{"role":"user","content":"..."}, {"role":"assistant","content":"..."}]
    Also converts the old tuple format [(user, bot), ...] -> messages.
    """
    if history is None:
        return []
    # Already messages format
    if (
        len(history) > 0
        and isinstance(history[0], dict)
        and "role" in history[0]
        and "content" in history[0]
    ):
        fixed = []
        for m in history:
            fixed.append({"role": m["role"], "content": extract_text(m["content"])})
        return fixed
    # Tuple/pair format -> messages format
    fixed = []
    for pair in history:
        if isinstance(pair, (list, tuple)) and len(pair) == 2:
            fixed.append({"role": "user", "content": extract_text(pair[0])})
            fixed.append({"role": "assistant", "content": extract_text(pair[1])})
    return fixed
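
# Example (illustrative): the legacy tuple format is upgraded like so:
#   ensure_messages_history([("hi", "hello!")])
#   -> [{"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "hello!"}]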
# --------- Safe Math Solver ----------
_ALLOWED_OPS = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Mod: op.mod,
    ast.Pow: op.pow,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
}
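# This whitelist is what keeps the evaluator safe: only these AST node types
# are ever interpreted, so names, attributes, and function calls can never
# execute (they fall through to the ValueError in safe_eval_math below).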
def safe_eval_math(expr: str):
    """
    Safely evaluate an arithmetic expression (no variables, no function calls).
    Supports + - * / % ** and parentheses, on integers and floats.
    """
    node = ast.parse(expr, mode="eval").body

    def _eval(n):
        if isinstance(n, ast.Constant) and isinstance(n.value, (int, float)):
            return n.value
        if isinstance(n, ast.BinOp) and type(n.op) in _ALLOWED_OPS:
            return _ALLOWED_OPS[type(n.op)](_eval(n.left), _eval(n.right))
        if isinstance(n, ast.UnaryOp) and type(n.op) in _ALLOWED_OPS:
            return _ALLOWED_OPS[type(n.op)](_eval(n.operand))
        raise ValueError("Unsupported expression")

    return _eval(node)
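
# Quick examples (illustrative): safe_eval_math("2**5") -> 32 and
# safe_eval_math("(10+5)/3") -> 5.0, while anything outside pure arithmetic,
# e.g. safe_eval_math("__import__('os')"), raises ValueError.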
def try_math_answer(user_text: str):
    """
    If the user text looks like a pure math expression, return the computed
    answer as a string. Otherwise return None.
    Examples:
    "2+2=?" -> "4"
    "2^5" -> "32"
    "(10+5)/3" -> "5"
    """
    if not user_text:
        return None
    # Normalize common decorations: strip only a trailing "=" / "?" so that
    # "2+2=?" -> "2+2" without mangling any "=" inside the expression
    s = user_text.strip()
    s = re.sub(r"[=?\s]+$", "", s)
    s = s.replace("^", "**")  # allow ^ as power
    # Only allow digits/operators/parentheses/dots/spaces
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\(\)%]+", s):
        return None
    try:
        val = safe_eval_math(s)
        # pretty formatting: 4.0 -> 4
        if isinstance(val, float) and val.is_integer():
            val = int(val)
        return str(val)
    except Exception:
        return None
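
# Examples (illustrative): try_math_answer("2+2=?") -> "4",
# try_math_answer("2^5") -> "32", and try_math_answer("what is 2+2?") -> None
# (letters fail the character whitelist, so the model answers instead).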
# --------- Model init ----------
def initialize():
    """Initialize the assistant (load the saved model if it exists, else train once)."""
    global model, tokenizer
    print("Initializing Veda Programming Assistant...")
    config_path = os.path.join(MODEL_DIR, "config.json")
    if os.path.exists(config_path):
        print("Loading existing model...")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = VedaTokenizer()
        tokenizer.load(os.path.join(MODEL_DIR, "tokenizer.json"))
        model = VedaProgrammingLLM(
            vocab_size=config["vocab_size"],
            max_length=config["max_length"],
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            num_layers=config["num_layers"],
            ff_dim=config["ff_dim"],
        )
        # Build the model with a dummy forward pass so the weights can be loaded
        dummy = tf.zeros((1, config["max_length"]), dtype=tf.int32)
        model(dummy)
        model.load_weights(os.path.join(MODEL_DIR, "weights.h5"))
        print("Model loaded!")
    else:
        print("No saved model found. Training a new model...")
        trainer = VedaTrainer()
        trainer.train(epochs=15)
        model = trainer.model
        tokenizer = trainer.tokenizer
        print("Model trained!")
def clean_response(text: str) -> str:
    """Clean the response text for display."""
    text = text.replace("<CODE>", "\n```python\n")
    text = text.replace("<ENDCODE>", "\n```\n")
    for token in ["<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"]:
        text = text.replace(token, "")
    # Collapse runs of three or more blank lines down to two
    lines = text.split("\n")
    cleaned = []
    empty_count = 0
    for line in lines:
        if line.strip() == "":
            empty_count += 1
            if empty_count <= 2:
                cleaned.append(line)
        else:
            empty_count = 0
            cleaned.append(line)
    return "\n".join(cleaned).strip()
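
# Example (illustrative): raw model output such as
#   "Here you go: <CODE>print('hi')<ENDCODE> <END>"
# becomes prose followed by a fenced ```python code block, with the special
# tokens stripped and extra blank lines collapsed.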
def generate_response(user_input: str, temperature: float = 0.7, max_tokens: int = 200) -> str:
    """Generate a response from the model OR solve math deterministically."""
    global current_conv_id, conversation_history
    # Convert Gradio multimodal payloads -> plain text
    user_input = extract_text(user_input).strip()
    if not user_input:
        return "Please type a message!"
    # 1) Try the math solver first
    math_ans = try_math_answer(user_input)
    if math_ans is not None:
        # Persist math exchanges too, so feedback and stats cover them
        conversation_history.append({"user": user_input, "assistant": math_ans})
        current_conv_id = db.save_conversation(user_input, math_ans)
        return math_ans
    # 2) Otherwise fall back to the model
    if model is None:
        return "Model is loading, please wait..."
    try:
        # Use up to the last three exchanges as context
        context = ""
        for msg in conversation_history[-3:]:
            context += f"<USER> {msg['user']}\n<ASSISTANT> {msg['assistant']}\n"
        prompt = context + f"<USER> {user_input}\n<ASSISTANT>"
        tokens = tokenizer.encode(prompt)
        # Truncate from the left so there is room for max_tokens of generation
        if len(tokens) > model.max_length - max_tokens:
            tokens = tokens[-(model.max_length - max_tokens):]
        generated = model.generate(
            tokens,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
        )
        response = tokenizer.decode(generated)
        # Keep only the text after the last <ASSISTANT> tag, before any new <USER> turn
        if "<ASSISTANT>" in response:
            response = response.split("<ASSISTANT>")[-1].strip()
        if "<USER>" in response:
            response = response.split("<USER>")[0].strip()
        response = clean_response(response)
        if not response:
            response = "I'm not sure how to respond to that. Could you try rephrasing?"
        conversation_history.append({"user": user_input, "assistant": response})
        current_conv_id = db.save_conversation(user_input, response)
        return response
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}"
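
# Prompt layout (illustrative): with one prior exchange in conversation_history,
# the model is prompted with something like
#   "<USER> hi\n<ASSISTANT> hello!\n<USER> what is a list?\n<ASSISTANT>"
# and generation continues after the final <ASSISTANT> tag.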
# --------- Gradio handlers ----------
def respond(message, history, temperature, max_tokens):
    """Always return messages-format history."""
    history = ensure_messages_history(history)
    user_text = extract_text(message).strip()
    if not user_text:
        return "", history
    bot_message = generate_response(user_text, temperature, max_tokens)
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": bot_message})
    return "", history

def feedback_good():
    global current_conv_id
    if current_conv_id > 0:
        db.update_feedback(current_conv_id, 1)
        return "πŸ‘ Thanks for the positive feedback!"
    return "No conversation to rate yet."

def feedback_bad():
    global current_conv_id
    if current_conv_id > 0:
        db.update_feedback(current_conv_id, -1)
        return "πŸ‘Ž Thanks! I'll try to improve."
    return "No conversation to rate yet."

def clear_chat():
    global conversation_history
    conversation_history = []
    return [], "Chat cleared."
def retrain(epochs):
    """Retrain on conversations that received positive feedback."""
    global model, tokenizer
    good_convs = db.get_good_conversations()
    if not good_convs:
        return "No approved conversations yet. Rate some responses as 'Good' first!"
    # Append approved exchanges to the training data in the prompt format
    extra_data = ""
    for conv in good_convs:
        extra_data += f"<USER> {conv['user_input']}\n"
        extra_data += f"<ASSISTANT> {conv['assistant_response']}\n\n"
    trainer = VedaTrainer()
    history = trainer.train(epochs=int(epochs), extra_data=extra_data)
    model = trainer.model
    tokenizer = trainer.tokenizer
    loss = history.history["loss"][-1]
    return f"βœ… Training complete! Loss: {loss:.4f}, used {len(good_convs)} conversations."
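
# Illustrative extra_data fed to VedaTrainer.train (shape follows the loop above):
#   "<USER> what is a list?\n<ASSISTANT> A list is an ordered collection...\n\n"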
def get_stats():
    stats = db.get_stats()
    return f"""## πŸ“Š Statistics
| Metric | Count |
|--------|-------|
| πŸ’¬ Total Conversations | {stats['total']} |
| πŸ‘ Positive Feedback | {stats['positive']} |
| πŸ‘Ž Negative Feedback | {stats['negative']} |
"""
# --------- Startup ----------
print("Starting initialization...")
initialize()
print("Initialization complete!")
# --------- UI ----------
with gr.Blocks(title="Veda Programming Assistant") as demo:
    gr.Markdown(
        """
        # πŸ•‰οΈ Veda Programming Assistant
        Now supports **math** (e.g., `2+2=?`, `(10+5)/3`, `2^5`) as well as coding and chat.
        """
    )
    with gr.Tabs():
        with gr.TabItem("πŸ’¬ Chat"):
            # type="messages" matches the dict-based history built in respond()
            chatbot = gr.Chatbot(label="Conversation", height=400, value=[], type="messages")
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Ask me anything about programming... or type math like 2+2=?",
                    lines=2,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                temperature = gr.Slider(0.1, 1.5, 0.7, step=0.1, label="Creativity")
                max_tokens = gr.Slider(50, 400, 200, step=50, label="Response length")
            with gr.Row():
                good_btn = gr.Button("πŸ‘ Good", variant="secondary")
                bad_btn = gr.Button("πŸ‘Ž Bad", variant="secondary")
                clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
            feedback_msg = gr.Textbox(label="Status", lines=1, interactive=False)
            send_btn.click(respond, inputs=[msg, chatbot, temperature, max_tokens], outputs=[msg, chatbot])
            msg.submit(respond, inputs=[msg, chatbot, temperature, max_tokens], outputs=[msg, chatbot])
            good_btn.click(feedback_good, outputs=feedback_msg)
            bad_btn.click(feedback_bad, outputs=feedback_msg)
            clear_btn.click(clear_chat, outputs=[chatbot, feedback_msg])
            gr.Markdown("### πŸ’‘ Examples")
            gr.Examples(
                examples=[
                    ["2+2=?"],
                    ["(10+5)/3"],
                    ["2^8"],
                    ["What is Python?"],
                    ["Write a function to calculate factorial"],
                    ["Explain recursion"],
                ],
                inputs=msg,
            )
        with gr.TabItem("πŸŽ“ Training"):
            gr.Markdown(
                """
                ### Improve the Assistant
                1. Chat with the assistant
                2. Rate good responses with πŸ‘
                3. Click "Retrain Model" to learn from the approved conversations
                """
            )
            train_epochs = gr.Slider(5, 20, 10, step=1, label="Training Epochs")
            train_btn = gr.Button("πŸ”„ Retrain Model", variant="primary")
            train_output = gr.Markdown()
            train_btn.click(retrain, inputs=[train_epochs], outputs=train_output)
        with gr.TabItem("πŸ“Š Statistics"):
            stats_out = gr.Markdown()
            refresh_btn = gr.Button("πŸ”„ Refresh Statistics")
            refresh_btn.click(get_stats, outputs=stats_out)
    gr.Markdown("---\n**Veda Programming Assistant**")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)