Spaces:
Running
Running
| """Gradio App for Veda Programming Assistant - Gradio 6.x compatible (with math solving)""" | |
| import gradio as gr | |
| import tensorflow as tf | |
| import os | |
| import json | |
| import re | |
| import ast | |
| import operator as op | |
| from model import VedaProgrammingLLM | |
| from tokenizer import VedaTokenizer | |
| from database import db | |
| from train import VedaTrainer | |
| from config import MODEL_DIR | |
| # --------- Globals ---------- | |
| # Populated by initialize() at startup and replaced by retrain(). | |
| model = None | |
| tokenizer = None | |
| conversation_history = [] # used for building prompt context for the model | |
| # Id returned by db.save_conversation() for the latest exchange; the feedback | |
| # handlers only act when it is > 0, so -1 means "nothing to rate yet". | |
| current_conv_id = -1 | |
| # --------- Helpers (Gradio message parsing) ---------- | |
def extract_text(message):
    """
    Flatten a Gradio message payload into a plain string.

    Accepted shapes:
    - None -> ""
    - str -> returned unchanged
    - dict -> its "text" entry (stringified), else its "content" entry
      (processed recursively), else ""
    - list -> concatenation of plain strings and {"type": "text"} parts,
      with surrounding whitespace stripped
    Anything else is passed through str().
    """
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    if isinstance(message, dict):
        if "text" in message:
            return str(message.get("text", ""))
        return extract_text(message["content"]) if "content" in message else ""
    if not isinstance(message, list):
        return str(message)

    # List of parts: keep only the textual ones, in order.
    chunks = []
    for item in message:
        if isinstance(item, str):
            chunks.append(item)
        elif isinstance(item, dict) and item.get("type") == "text":
            chunks.append(str(item.get("text", "")))
    return "".join(chunks).strip()
def ensure_messages_history(history):
    """
    Normalise Chatbot history to the messages format:
    [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]

    A history whose first entry already carries "role" and "content" is
    rebuilt with every content flattened to text. Otherwise each 2-item
    list/tuple is treated as a (user, bot) pair and expanded into two
    messages; entries of any other shape are dropped. None yields [].
    """
    if history is None:
        return []

    head = history[0] if len(history) > 0 else None
    if isinstance(head, dict) and "role" in head and "content" in head:
        # Already messages format: only the content needs flattening.
        return [
            {"role": m["role"], "content": extract_text(m["content"])}
            for m in history
        ]

    # Legacy pair format -> messages format.
    messages = []
    for entry in history:
        if not isinstance(entry, (list, tuple)) or len(entry) != 2:
            continue
        user_part, bot_part = entry
        messages.append({"role": "user", "content": extract_text(user_part)})
        messages.append({"role": "assistant", "content": extract_text(bot_part)})
    return messages
| # --------- Safe Math Solver ---------- | |
_ALLOWED_OPS = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Mod: op.mod,
    ast.Pow: op.pow,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
}

# Rough cap on the size (in bits) of an integer power result. Without it a
# short input such as "9**9**9**9" would try to build an astronomically large
# integer and stall or exhaust the process.
_MAX_POW_RESULT_BITS = 10_000


def _bounded_pow(base, exponent):
    """Return base ** exponent, rejecting huge integer results and complex results."""
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        and base.bit_length() * exponent > _MAX_POW_RESULT_BITS
    ):
        raise ValueError("Exponent too large")
    result = op.pow(base, exponent)
    # A negative base with a fractional exponent yields a complex number,
    # which is not a meaningful arithmetic answer here.
    if isinstance(result, complex):
        raise ValueError("Unsupported expression")
    return result


def safe_eval_math(expr: str):
    """
    Safely evaluate arithmetic expression (no variables, no function calls).
    Supports: + - * / % ** and parentheses, integers/floats.

    The expression is parsed with ``ast`` and only numeric constants, the
    binary operators in ``_ALLOWED_OPS`` and unary +/- are evaluated; nothing
    is ever passed to ``eval``.

    Returns:
        The numeric result (int or float).

    Raises:
        ValueError: for any unsupported construct (names, calls, booleans,
            other operators), for integer powers whose result would be
            excessively large, and for powers with a complex result.
        SyntaxError: if ``expr`` is not a valid Python expression.
        ZeroDivisionError, OverflowError: propagated from the arithmetic.
    """
    node = ast.parse(expr, mode="eval").body

    def _eval(n):
        if isinstance(n, ast.Constant):
            value = n.value
            # bool is a subclass of int; True/False are not arithmetic literals.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
            raise ValueError("Unsupported expression")
        if isinstance(n, ast.BinOp) and type(n.op) in _ALLOWED_OPS:
            left = _eval(n.left)
            right = _eval(n.right)
            if isinstance(n.op, ast.Pow):
                return _bounded_pow(left, right)
            return _ALLOWED_OPS[type(n.op)](left, right)
        if isinstance(n, ast.UnaryOp) and type(n.op) in _ALLOWED_OPS:
            return _ALLOWED_OPS[type(n.op)](_eval(n.operand))
        raise ValueError("Unsupported expression")

    return _eval(node)
def try_math_answer(user_text: str):
    """
    If user text looks like a pure math expression, return computed answer as string.
    Otherwise return None.

    Leading/trailing "=", "?" and whitespace are treated as decoration and
    removed; "^" is accepted as the power operator. An "=" or "?" in the
    middle of the text is NOT removed, because dropping it would glue the two
    sides together ("2+2=5" would otherwise be evaluated as "2+25").

    Examples:
    "2+2=?" -> "4"
    "2^5" -> "32"
    "(10+5)/3" -> "5"
    "2+2=5" -> None

    Returns:
        The result as a string (whole-valued floats are shown without ".0"),
        or None when the text is not a pure expression or evaluation fails.
    """
    if not user_text:
        return None

    # Normalize common decorations at the edges only.
    expr = user_text.strip("=? \t\r\n")
    expr = expr.replace("^", "**")  # allow ^ as power

    # Only allow digits/operators/parentheses/dots/spaces
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\(\)%]+", expr):
        return None

    try:
        val = safe_eval_math(expr)
        # pretty formatting: 4.0 -> 4
        if isinstance(val, float) and val.is_integer():
            val = int(val)
        return str(val)
    except Exception:
        # Best effort by design: anything that cannot be evaluated (syntax
        # error, division by zero, unsupported construct, ...) is simply
        # "not a math question" and falls through to the model.
        return None
| # --------- Model init ---------- | |
def initialize():
    """
    Prepare the global model and tokenizer.

    When MODEL_DIR contains a config.json, the tokenizer and weights saved
    there are loaded; otherwise a fresh model is trained for 15 epochs.
    """
    global model, tokenizer
    print("Initializing Veda Programming Assistant...")
    config_path = os.path.join(MODEL_DIR, "config.json")

    if not os.path.exists(config_path):
        print("No saved model found. Training a new model...")
        trainer = VedaTrainer()
        trainer.train(epochs=15)
        model = trainer.model
        tokenizer = trainer.tokenizer
        print("Model trained!")
        return

    print("Loading existing model...")
    with open(config_path, "r") as f:
        config = json.load(f)

    tokenizer = VedaTokenizer()
    tokenizer.load(os.path.join(MODEL_DIR, "tokenizer.json"))

    hyperparams = {
        key: config[key]
        for key in ("vocab_size", "max_length", "d_model", "num_heads", "num_layers", "ff_dim")
    }
    model = VedaProgrammingLLM(**hyperparams)

    # Run one dummy batch through the model before loading the weights.
    warmup_batch = tf.zeros((1, config["max_length"]), dtype=tf.int32)
    model(warmup_batch)
    model.load_weights(os.path.join(MODEL_DIR, "weights.h5"))
    print("Model loaded!")
def clean_response(text: str) -> str:
    """
    Turn raw decoded model output into displayable text.

    <CODE>/<ENDCODE> become fenced python blocks, the remaining special
    tokens are removed, runs of blank lines are capped at two, and the
    result is stripped.
    """
    text = text.replace("<CODE>", "\n```python\n").replace("<ENDCODE>", "\n```\n")
    for marker in ("<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"):
        text = text.replace(marker, "")

    kept = []
    blank_run = 0
    for row in text.split("\n"):
        if row.strip():
            blank_run = 0
            kept.append(row)
            continue
        # Blank line: keep it only while the current run is at most two long.
        blank_run += 1
        if blank_run <= 2:
            kept.append(row)
    return "\n".join(kept).strip()
def generate_response(user_input: str, temperature: float = 0.7, max_tokens: int = 200) -> str:
    """
    Generate a response from the model OR solve math deterministically.

    Args:
        user_input: The user's message (plain text or a Gradio message
            payload; it is flattened with extract_text()).
        temperature: Sampling temperature forwarded to model.generate().
        max_tokens: Requested number of new tokens. Coerced to an int and
            capped so that prompt + generation fit in model.max_length.

    Returns:
        The reply text. Special cases: a prompt to type something when the
        input is empty, a "loading" notice when no model is available, and
        an "Error: ..." string if generation raises.

    Side effects:
        On success the exchange is appended to conversation_history and saved
        through db.save_conversation(), whose return value becomes
        current_conv_id.
    """
    global current_conv_id, conversation_history

    # Convert Gradio multimodal -> text
    user_input = extract_text(user_input).strip()
    if not user_input:
        return "Please type a message!"

    # 1) Try math solver first
    math_ans = try_math_answer(user_input)
    if math_ans is not None:
        # Save conversation too (optional)
        conversation_history.append({"user": user_input, "assistant": math_ans})
        current_conv_id = db.save_conversation(user_input, math_ans)
        return math_ans

    # 2) Otherwise use model
    if model is None:
        return "Model is loading, please wait..."

    try:
        # UI sliders may deliver floats; the token budget and slicing need ints.
        max_tokens = max(1, int(max_tokens))

        # Only the three most recent exchanges are used as context.
        context = ""
        for msg in conversation_history[-3:]:
            context += f"<USER> {msg['user']}\n<ASSISTANT> {msg['assistant']}\n"
        prompt = context + f"<USER> {user_input}\n<ASSISTANT>"
        tokens = tokenizer.encode(prompt)

        # Cap generation so at least one prompt token always fits. Without the
        # cap, max_tokens >= model.max_length makes the budget zero/negative
        # and the slice below would keep the wrong part of the prompt.
        max_new_tokens = min(max_tokens, model.max_length - 1)
        prompt_budget = model.max_length - max_new_tokens
        if len(tokens) > prompt_budget:
            # Keep the most recent tokens; older context is dropped first.
            tokens = tokens[-prompt_budget:]

        generated = model.generate(
            tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
        )
        response = tokenizer.decode(generated)

        # Keep only the text after the last <ASSISTANT> marker and before any
        # following <USER> marker.
        if "<ASSISTANT>" in response:
            response = response.split("<ASSISTANT>")[-1].strip()
        if "<USER>" in response:
            response = response.split("<USER>")[0].strip()
        response = clean_response(response)
        if not response:
            response = "I'm not sure how to respond to that. Could you try rephrasing?"

        conversation_history.append({"user": user_input, "assistant": response})
        current_conv_id = db.save_conversation(user_input, response)
        return response
    except Exception as e:
        # Top-level boundary for the UI: log the traceback and show the error
        # text instead of crashing the request.
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}"
| # --------- Gradio handlers ---------- | |
def respond(message, history, temperature, max_tokens):
    """
    Chat handler: returns ("", history) with history always in messages format.

    The empty string is the first output. When the message contains text,
    the user turn and the generated assistant turn are appended to the
    history; an empty message leaves the (normalised) history unchanged.
    """
    history = ensure_messages_history(history)
    user_text = extract_text(message).strip()
    if user_text:
        reply = generate_response(user_text, temperature, max_tokens)
        history.extend(
            [
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": reply},
            ]
        )
    return "", history
def feedback_good():
    """Record positive feedback (+1) for the latest saved conversation, if any, and return a status message."""
    global current_conv_id
    if not current_conv_id > 0:
        return "No conversation to rate yet."
    db.update_feedback(current_conv_id, 1)
    return "π Thanks for the positive feedback!"
def feedback_bad():
    """Record negative feedback (-1) for the latest saved conversation, if any, and return a status message."""
    global current_conv_id
    if not current_conv_id > 0:
        return "No conversation to rate yet."
    db.update_feedback(current_conv_id, -1)
    return "π Thanks! I'll try to improve."
def clear_chat():
    """Reset the prompt-context history and return (empty chat history, status message)."""
    global conversation_history
    # Rebind to a fresh list rather than clearing the existing one in place.
    conversation_history = list()
    emptied_chat = []
    return emptied_chat, "Chat cleared."
def retrain(epochs):
    """
    Retrain on the conversations returned by db.get_good_conversations().

    Returns a status string: a hint to rate responses first when there are
    no such conversations, otherwise the final training loss and the number
    of conversations used. On success the global model and tokenizer are
    replaced by the trainer's.
    """
    global model, tokenizer
    approved = db.get_good_conversations()
    if not approved:
        return "No approved conversations yet. Rate some responses as 'Good' first!"

    # One "<USER> ...\n<ASSISTANT> ...\n\n" record per approved conversation.
    corpus = "".join(
        f"<USER> {conv['user_input']}\n<ASSISTANT> {conv['assistant_response']}\n\n"
        for conv in approved
    )

    trainer = VedaTrainer()
    fit_result = trainer.train(epochs=int(epochs), extra_data=corpus)
    model = trainer.model
    tokenizer = trainer.tokenizer

    final_loss = fit_result.history["loss"][-1]
    return f"β Training complete! Loss: {final_loss:.4f}, Used {len(approved)} conversations"
| def get_stats(): | |
| """Return a Markdown table built from db.get_stats(), using its 'total', 'positive' and 'negative' entries.""" | |
| stats = db.get_stats() | |
| return f"""## π Statistics | |
| | Metric | Count | | |
| |--------|-------| | |
| | π¬ Total Conversations | {stats['total']} | | |
| | π Positive Feedback | {stats['positive']} | | |
| | π Negative Feedback | {stats['negative']} | | |
| """ | |
| # --------- Startup ---------- | |
| # Runs at import time, before the UI is built: initialize() either loads the | |
| # saved model or trains a new one, so module import can take a while. | |
| print("Starting initialization...") | |
| initialize() | |
| print("Initialization complete!") | |
| # --------- UI ---------- | |
| # Build the Blocks UI: a Chat tab, a Training tab and a Statistics tab. | |
| with gr.Blocks(title="Veda Programming Assistant") as demo: | |
| gr.Markdown( | |
| """ | |
| # ποΈ Veda Programming Assistant | |
| Now supports **math** (e.g., `2+2=?`, `(10+5)/3`, `2^5`) plus coding/chatting. | |
| """ | |
| ) | |
| with gr.Tabs(): | |
| with gr.TabItem("π¬ Chat"): | |
| chatbot = gr.Chatbot(label="Conversation", height=400, value=[]) | |
| with gr.Row(): | |
| msg = gr.Textbox( | |
| label="Your message", | |
| placeholder="Ask me anything about programming... or type math like 2+2=?", | |
| lines=2, | |
| scale=4, | |
| ) | |
| send_btn = gr.Button("Send", variant="primary", scale=1) | |
| with gr.Row(): | |
| temperature = gr.Slider(0.1, 1.5, 0.7, step=0.1, label="Creativity") | |
| max_tokens = gr.Slider(50, 400, 200, step=50, label="Response length") | |
| with gr.Row(): | |
| good_btn = gr.Button("π Good", variant="secondary") | |
| bad_btn = gr.Button("π Bad", variant="secondary") | |
| clear_btn = gr.Button("ποΈ Clear", variant="secondary") | |
| feedback_msg = gr.Textbox(label="Status", lines=1, interactive=False) | |
| # The Send button and the textbox submit event both call respond(), whose | |
| # two return values ("" and the updated history) go to msg and chatbot. | |
| send_btn.click(respond, inputs=[msg, chatbot, temperature, max_tokens], outputs=[msg, chatbot]) | |
| msg.submit(respond, inputs=[msg, chatbot, temperature, max_tokens], outputs=[msg, chatbot]) | |
| # Feedback and clear handlers write their status string to the Status textbox. | |
| good_btn.click(feedback_good, outputs=feedback_msg) | |
| bad_btn.click(feedback_bad, outputs=feedback_msg) | |
| clear_btn.click(clear_chat, outputs=[chatbot, feedback_msg]) | |
| gr.Markdown("### π‘ Examples") | |
| # Example prompts targeting the message textbox (inputs=msg). | |
| gr.Examples( | |
| examples=[ | |
| ["2+2=?"], | |
| ["(10+5)/3"], | |
| ["2^8"], | |
| ["What is Python?"], | |
| ["Write a function to calculate factorial"], | |
| ["Explain recursion"], | |
| ], | |
| inputs=msg, | |
| ) | |
| with gr.TabItem("π Training"): | |
| gr.Markdown( | |
| """ | |
| ### Improve the Assistant | |
| 1. Chat with the assistant | |
| 2. Rate good responses with π | |
| 3. Click "Retrain Model" to learn from good conversations | |
| """ | |
| ) | |
| train_epochs = gr.Slider(5, 20, 10, step=1, label="Training Epochs") | |
| train_btn = gr.Button("π Retrain Model", variant="primary") | |
| train_output = gr.Markdown() | |
| # retrain() returns a status string that is rendered in train_output. | |
| train_btn.click(retrain, inputs=[train_epochs], outputs=train_output) | |
| with gr.TabItem("π Statistics"): | |
| stats_out = gr.Markdown() | |
| refresh_btn = gr.Button("π Refresh Statistics") | |
| # The only wiring for stats_out here is this button click. | |
| refresh_btn.click(get_stats, outputs=stats_out) | |
| gr.Markdown("---\n**Veda Programming Assistant**") | |
| # Launch the server only when run as a script; listens on 0.0.0.0:7860. | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |