# app.py — Hugging Face Space "sayalimetkar/quant_model" (commit caf786d, verified)
import re
import ast
import operator
import gradio as gr
from ctransformers import AutoModelForCausalLM
# ------------------------------
# MODEL CONFIGURATION
# ------------------------------
# Hugging Face Hub repository id of the quantized model.
# NOTE(review): the from_pretrained() call further down repeats this string
# literally — keep the two in sync.
MODEL_PATH = "sayalimetkar/quant_model"
# System message prepended verbatim to every prompt sent to the model.
SYSTEM_PROMPT = """
You are a highly capable and reliable AI assistant.
You understand and write code in Python, SQL, JavaScript, and other languages accurately.
You also solve math and logic problems step-by-step.
Always give clear, direct, and correct answers without unnecessary explanation.
If the user asks for code, return only properly formatted and working code.
If the user asks for a calculation, show the reasoning and give the exact result.
Always think step by step and explain the reasoning before giving the final answer.
"""
# Few-shot example transcripts keyed by detected task type; format_prompt()
# picks "code" or "math" (or none) based on the user's message.
# NOTE(review): the "math" examples read "... (unchanged) ..." in this
# revision — looks like a truncated paste; confirm the full examples exist
# in the deployed file.
FEW_SHOT_EXAMPLES = {
"math": """... (unchanged) ...
### END OF EXAMPLES
""",
"code": """
User: Write a Python function to return all even numbers from a list.
Assistant:
def filter_even(nums):
return [n for n in nums if n % 2 == 0]
User: Write a Python function to compute the factorial of a number.
Assistant:
def factorial(n):
if n == 0:
return 1
return n * factorial(n - 1)
### END OF EXAMPLES
"""
}
# ------------------------------
# FORMAT PROMPT
# ------------------------------
def format_prompt(system: str, history: list[tuple[str, str]], user_input: str) -> str:
    """Assemble the full model prompt.

    Combines the system message, task-specific few-shot examples, the prior
    conversation turns, and a strict single-task instruction tail that keeps
    the model from adding extra or unrelated content.
    """
    # --- Detect query type and pick matching few-shot examples ---
    code_keywords = r'\b(def|SELECT|INSERT|UPDATE|print|for|while|if|class)\b'
    if re.search(code_keywords, user_input, re.I):
        task_type = "code"
        few_shot = FEW_SHOT_EXAMPLES["code"]
    elif is_math_question(user_input):
        task_type = "math"
        few_shot = FEW_SHOT_EXAMPLES["math"]
    else:
        task_type = "general"
        few_shot = ""
    # --- Base: system message, examples, then the prior turns ---
    pieces = [system.strip(), "\n\n", few_shot.strip(), "\n\n"]
    pieces.extend(
        f"User: {past_user}\nAssistant: {past_assistant}\n"
        for past_user, past_assistant in history
    )
    # --- Task-specific instruction tail (explicit, single-task) ---
    instruction_tails = {
        "code": (
            "Assistant: Please provide ONLY the corrected or required code block. "
            "Do NOT include explanations or any unrelated topics.\n### RESPONSE:\n"
        ),
        "math": (
            "Assistant: Let's think step by step. Then provide ONLY the final numeric answer, "
            "on a new line prefixed by 'Final Answer:'.\n### RESPONSE:\n"
        ),
        "general": (
            "Assistant: Provide a concise and direct answer. "
            "Do NOT add examples, explanations, or unrelated information.\n### RESPONSE:\n"
        ),
    }
    pieces.append(f"User: {user_input}\n")
    pieces.append(instruction_tails[task_type])
    return "".join(pieces)
# ------------------------------
# SAFE MATH SOLVER
# ------------------------------
# Lookup table mapping AST binary-operator node types to their arithmetic
# implementations (add, subtract, multiply, true-divide, power, modulo,
# floor-divide). Any operator absent from this table is unsupported.
operators = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.FloorDiv: operator.floordiv
}
def safe_eval(expr):
    """Safely evaluate a pure arithmetic expression via the AST.

    Only numeric literals, the binary operators + - * / ** % //, and unary
    +/- are permitted. Names, calls, attribute access, strings, etc. are
    all rejected, so untrusted input cannot execute code.

    Args:
        expr: expression source string, e.g. "2+3*4".

    Returns:
        The numeric result (int or float).

    Raises:
        ValueError: unsupported node or operator, or a non-numeric literal.
        SyntaxError: *expr* is not parseable Python.
        ZeroDivisionError / OverflowError: from the arithmetic itself.
    """
    # Self-contained operator table: unknown operators raise ValueError
    # (the original indexed a module-level dict and leaked KeyError).
    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
        ast.Mod: operator.mod,
        ast.FloorDiv: operator.floordiv,
    }

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.BinOp):
            op = ops.get(type(node.op))
            if op is None:
                raise ValueError("Unsupported binary operator")
            return op(_eval(node.left), _eval(node.right))
        # ast.Constant replaces ast.Num, which is deprecated since Python 3.8.
        if isinstance(node, ast.Constant):
            # Exclude bool: True/False parse as constants but are not numbers here.
            if isinstance(node.value, (int, float)) and not isinstance(node.value, bool):
                return node.value
            raise ValueError("Only numeric literals are allowed")
        if isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return -_eval(node.operand)
            if isinstance(node.op, ast.UAdd):
                return +_eval(node.operand)
            raise ValueError("Unsupported unary operator")
        raise ValueError("Unsupported expression")

    return _eval(ast.parse(expr, mode='eval'))
def is_math_question(user_input):
    """Heuristically decide whether *user_input* looks like a math question.

    Matches a digit followed by an operator/space/paren, or common word-problem
    keywords (series, average, percent, speed, time, distance).
    """
    text = user_input.lower()
    pattern = r'(\d+[\s\+\-\*/^()]|\bseries\b|\baverage\b|\bpercent|\bspeed|\btime|\bdistance\b)'
    return re.search(pattern, text) is not None
def solve_math(user_input):
    """Try to solve a plain arithmetic expression embedded in *user_input*.

    Args:
        user_input: raw user message; non-expression characters are stripped.

    Returns:
        The result as a string, or None when no evaluable expression remains
        (so the caller can fall back to the language model).
    """
    # Keep only digits, arithmetic operators, parentheses, '.', '^' and '%'.
    expr = re.sub(r'[^0-9+\-*/().^%]', '', user_input)
    if not expr:
        return None
    # Users commonly type '^' for exponentiation; Python uses '**'.
    expr = expr.replace('^', '**')
    try:
        return str(safe_eval(expr))
    except (ValueError, SyntaxError, KeyError, ArithmeticError, RecursionError):
        # Narrowed from a bare `except:` — these are the failures safe_eval can
        # produce (unsupported syntax, bad parse, unknown operator, div-by-zero,
        # overflow, deep nesting). Anything else should surface, not be hidden.
        return None
# ------------------------------
# LOAD MODEL
# ------------------------------
# Load the quantized Mistral model via ctransformers.
# Generation defaults: low temperature for deterministic answers, nucleus +
# top-k sampling, mild repetition penalty, 4k-token context window.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,  # consistency fix: reuse the constant instead of re-typing the repo id
    model_type="mistral",
    temperature=0.2,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
    context_length=4096,
    max_new_tokens=800
)
# ------------------------------
# STREAM REPLY FUNCTION
# ------------------------------
# Mutable flag shared with the Stop button handler; checked between tokens.
stop_flag = {"stop": False}
def stream_reply(user_input, history):
    """Yield successive chat histories as the assistant's reply is produced.

    Direct arithmetic questions are answered locally via solve_math();
    everything else streams token-by-token from the language model, honoring
    the shared stop flag.
    """
    stop_flag["stop"] = False
    # 1️⃣ Handle direct arithmetic without invoking the model.
    if is_math_question(user_input):
        math_answer = solve_math(user_input)
        if math_answer:
            # BUG FIX: the original referenced `partial` here before it was
            # ever assigned (NameError). The locally computed answer is what
            # should be shown.
            yield history + [(user_input, math_answer)]
            return
    # 2️⃣ Let model handle reasoning or coding.
    prompt = format_prompt(SYSTEM_PROMPT, history, user_input)
    partial = ""
    for token in model(prompt, stream=True):
        if stop_flag["stop"]:
            break
        partial += token
        # Strip any role prefixes the model echoes back into its own output.
        cleaned = re.sub(r"(?i)(User:|Assistant:)", "", partial).strip()
        yield history + [(user_input, cleaned)]
# ------------------------------
# GRADIO UI
# ------------------------------
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chatbot")
    msg = gr.Textbox(label="Your message")
    send = gr.Button("Send")
    stop = gr.Button("🛑 Stop Response")
    reset = gr.Button("🔄 Reset Chat")
    # Append the user's message with an empty placeholder reply, and clear the box.
    def user_submit(user_message, history):
        return "", history + [(user_message, "")]
    # BUG FIX: the original chained stream_reply on [msg, chatbot], but by the
    # time .then() fires, user_submit has already cleared msg to "" — so the
    # model received an empty prompt and the user turn was duplicated in the
    # transcript. Read the pending message back out of the history instead.
    def bot_respond(history):
        user_message = history[-1][0]
        yield from stream_reply(user_message, history[:-1])
    # Reset chat
    def reset_chat():
        return []
    # Stop current generation (picked up between streamed tokens)
    def stop_generation():
        stop_flag["stop"] = True
        return None
    # UI Event Handlers
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond, chatbot, chatbot
    )
    send.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond, chatbot, chatbot
    )
    reset.click(reset_chat, outputs=chatbot)
    stop.click(stop_generation, None, None)
# ------------------------------
# LAUNCH APP
# ------------------------------
demo.launch(server_name="0.0.0.0", server_port=7860)