Vivek16 committed on
Commit
afa0209
·
verified ·
1 Parent(s): 268d8d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -30
app.py CHANGED
@@ -3,19 +3,19 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
 
6
- # --- Configuration ---
7
- # Your model repository ID
8
  BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
9
  ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
10
 
11
- # Define the instruction template components for conversational turn-taking
12
- # FIX: Relaxed the prompt to allow for general chat/greetings
13
- SYSTEM_INSTRUCTION = "You are a helpful and polite math assistant. Your main goal is to solve math problems, but you can also answer general greetings or small talk."
14
  USER_TEMPLATE = "<|user|>\n{}</s>"
15
  ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"
16
 
17
 
18
- # --- Model Loading Function ---
19
  def load_model():
20
  """Loads the base model and merges the LoRA adapters."""
21
  print("Loading base model...")
@@ -37,25 +37,17 @@ def load_model():
37
  print("Model loaded and merged successfully!")
38
  return tokenizer, model
39
 
40
- # Load the model outside the prediction function for efficiency
41
  tokenizer, model = load_model()
42
 
43
 
44
- # --- Prediction Function for gr.ChatInterface ---
45
  def generate_response(message, history):
46
  """Generates a response using chat history and the fine-tuned model."""
47
 
48
- # 1. Build the full prompt using the TinyLlama Chat template
49
-
50
  # Start with the system instruction
51
  full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"
52
 
53
- # --- FEW-SHOT EXAMPLE to handle greetings (FIXED "ciao" and "hi" issue) ---
54
- # This teaches the model how to handle a simple non-math exchange by providing a pattern.
55
- full_prompt += "<|user|>\nHello!</s>\n<|assistant|>\nHello! How can I assist you with a math problem today?</s>\n"
56
- # -------------------------------------------------------------------------
57
-
58
- # Append the actual chat history from the Gradio interface
59
  for user_msg, assistant_msg in history:
60
  full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
61
  full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"
@@ -63,11 +55,9 @@ def generate_response(message, history):
63
  # Append the current user message and the start of the assistant's turn
64
  full_prompt += USER_TEMPLATE.format(message) + "\n"
65
  full_prompt += "<|assistant|>\n"
66
-
67
- # 2. Tokenize the input
68
- inputs = tokenizer(full_prompt, return_tensors="pt")
69
 
70
- # 3. Generate the response (on CPU)
 
71
  with torch.no_grad():
72
  output_tokens = model.generate(
73
  **inputs,
@@ -78,13 +68,11 @@ def generate_response(message, history):
78
  pad_token_id=tokenizer.eos_token_id
79
  )
80
 
81
- # 4. Decode the output and extract only the new response
82
  generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
83
 
84
- # Find the start of the final assistant's turn in the output
85
  response_start = generated_text.rfind('<|assistant|>')
86
  if response_start != -1:
87
- # Get the text after <|assistant|> and strip the trailing </s>
88
  raw_response = generated_text[response_start + len('<|assistant|>'):].strip()
89
  assistant_response = raw_response.split('</s>')[0].strip()
90
  else:
@@ -93,17 +81,22 @@ def generate_response(message, history):
93
  return assistant_response
94
 
95
 
96
- # --- Gradio Chat Interface (Compatible) ---
97
- title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"
98
- description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"
99
 
100
  gr.ChatInterface(
101
  fn=generate_response,
102
- # Fix: Removed unsupported arguments for better Gradio version compatibility
103
- chatbot=gr.Chatbot(height=500),
104
- textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
 
 
 
105
  title=title,
106
  description=description,
107
- submit_btn="Ask Model",
 
 
108
  theme="soft"
109
  ).queue().launch()
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
 
6
+ # --- Configuration (Verified) ---
 
7
  BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
8
+ # Ensure this is correct for your model repository
9
  ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
10
 
11
+ # Define the instruction template components
12
+ # NEW: General, helpful assistant instruction
13
+ SYSTEM_INSTRUCTION = "You are a friendly and helpful assistant named Kutti. Your primary function is to solve problems and answer questions concisely. You should never mention being a math teacher or tutor."
14
  USER_TEMPLATE = "<|user|>\n{}</s>"
15
  ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"
16
 
17
 
18
+ # --- Model Loading Function (No change) ---
19
  def load_model():
20
  """Loads the base model and merges the LoRA adapters."""
21
  print("Loading base model...")
 
37
  print("Model loaded and merged successfully!")
38
  return tokenizer, model
39
 
 
40
  tokenizer, model = load_model()
41
 
42
 
43
+ # --- Prediction Function (No functional change, just uses new SYSTEM_INSTRUCTION) ---
44
  def generate_response(message, history):
45
  """Generates a response using chat history and the fine-tuned model."""
46
 
 
 
47
  # Start with the system instruction
48
  full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"
49
 
50
+ # Append the chat history (if any)
 
 
 
 
 
51
  for user_msg, assistant_msg in history:
52
  full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
53
  full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"
 
55
  # Append the current user message and the start of the assistant's turn
56
  full_prompt += USER_TEMPLATE.format(message) + "\n"
57
  full_prompt += "<|assistant|>\n"
 
 
 
58
 
59
+ # Tokenize and generate response
60
+ inputs = tokenizer(full_prompt, return_tensors="pt")
61
  with torch.no_grad():
62
  output_tokens = model.generate(
63
  **inputs,
 
68
  pad_token_id=tokenizer.eos_token_id
69
  )
70
 
 
71
  generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
72
 
73
+ # Extract only the model's new response
74
  response_start = generated_text.rfind('<|assistant|>')
75
  if response_start != -1:
 
76
  raw_response = generated_text[response_start + len('<|assistant|>'):].strip()
77
  assistant_response = raw_response.split('</s>')[0].strip()
78
  else:
 
81
  return assistant_response
82
 
83
 
84
+ # --- Gradio Chat Interface (Changes to Title/Initial Message) ---
85
+ title = "Kutti: Your TinyLlama Problem Solver"
86
+ description = "Hello! I'm Kutti. How can I help you? Ask me anything from math problems to general questions."
87
 
88
  gr.ChatInterface(
89
  fn=generate_response,
90
+ chatbot=gr.Chatbot(
91
+ height=500,
92
+ # Initial greeting set here:
93
+ value=[(None, "Hello! I'm Kutti. How can I help you today?")]
94
+ ),
95
+ textbox=gr.Textbox(placeholder="Ask your question or problem here...", scale=7),
96
  title=title,
97
  description=description,
98
+ submit_btn="Send", # Changed button text for a more conversational feel
99
+ clear_btn="Start New Chat",
100
+ undo_btn="Undo Last Message",
101
  theme="soft"
102
  ).queue().launch()