Vivek16 commited on
Commit
e4cb5da
·
verified ·
1 Parent(s): b2c590c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -63
app.py CHANGED
@@ -3,27 +3,23 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
 
6
- # --- Configuration ---
7
  BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
8
  ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
9
 
10
- # Define the single, strong system instruction
11
- SYSTEM_INSTRUCTION = (
12
- "You are a friendly, helpful, and highly SKILLED assistant named Kutti. "
13
- "Your responses MUST be concise and direct. You can handle any conversation, "
14
- "but when asked a problem (especially math), provide the correct step-by-step solution. "
15
- "DO NOT use excessive conversational filler or repetitive phrases. Stick to the point."
16
- )
17
 
18
 
19
  # --- Model Loading Function ---
20
  def load_model():
21
  """Loads the base model and merges the LoRA adapters."""
22
  print("Loading base model...")
23
- # Load the tokenizer, which includes the necessary chat template
24
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
25
-
26
- # Force loading to CPU as per your setup
27
  model = AutoModelForCausalLM.from_pretrained(
28
  BASE_MODEL_ID,
29
  torch_dtype=torch.bfloat16,
@@ -31,94 +27,84 @@ def load_model():
31
  )
32
 
33
  print("Loading and merging PEFT adapters...")
 
34
  model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID)
35
  model = model.merge_and_unload()
36
  model.eval()
37
 
 
38
  if tokenizer.pad_token is None:
39
  tokenizer.pad_token = tokenizer.eos_token
40
 
41
  print("Model loaded and merged successfully!")
42
  return tokenizer, model
43
 
 
44
  tokenizer, model = load_model()
45
 
46
 
47
- # --- Prediction Function (Modified for MAX stability and lower temperature) ---
48
  def generate_response(message, history):
49
- """Generates a response using the official chat template and generation constraints."""
50
 
51
- # 1. Prepare messages list for the chat template
52
- messages = []
53
 
54
- # Add the system instruction first
55
- messages.append({"role": "system", "content": SYSTEM_INSTRUCTION})
56
-
57
- # Add historical messages
58
- for message_dict in history:
59
- messages.append({"role": message_dict['role'], "content": message_dict['content']})
60
-
61
- # Add the current user message
62
- messages.append({"role": "user", "content": message})
63
-
64
- # 2. Apply the model's official chat template
65
- # NOTE: The "TinyLlama/TinyLlama-1.1B-Chat-v1.0" model expects a template like:
66
- # <|system|>\nSYSTEM_INSTRUCTION</s>\n<|user|>\nMESSAGE</s>\n<|assistant|>\n
67
- full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
68
 
69
- # 3. Tokenize the input
70
  inputs = tokenizer(full_prompt, return_tensors="pt")
71
-
72
- # 4. Generate the response with anti-repetition constraints and LOWER TEMPERATURE
73
  with torch.no_grad():
74
  output_tokens = model.generate(
75
  **inputs,
76
  max_new_tokens=256,
77
  do_sample=True,
78
- temperature=0.6, # Slightly lower temp for less gibberish
79
  top_k=50,
80
- pad_token_id=tokenizer.eos_token_id,
81
- # Constraints to prevent repetitive filler:
82
- no_repeat_ngram_size=5,
83
- repetition_penalty=1.5
84
  )
85
 
86
- # 5. Decode and clean the output using skip_special_tokens=True for max cleanup
87
- # We still need to find where the *new* response begins.
88
- generated_text_with_prompt = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
89
 
90
- # Extract only the model's new response by finding the last <|assistant|> tag
91
- # The last tag marks the beginning of the new response.
92
- assistant_prefix_tag = "<|assistant|>"
93
- response_start_index = generated_text_with_prompt.rfind(assistant_prefix_tag)
94
-
95
- if response_start_index != -1:
96
- # Get everything after the last <|assistant|> tag
97
- raw_response = generated_text_with_prompt[response_start_index + len(assistant_prefix_tag):].strip()
98
-
99
- # Clean up any trailing end-of-sequence tags (</s>) or user tags (<|user|>)
100
- assistant_response = raw_response.split("</s>")[0].split("<|user|>")[0].strip()
101
  else:
102
- # Fallback to the decoded text if the tag is not found (and hope for the best)
103
- assistant_response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
104
 
105
  return assistant_response
106
 
107
 
108
- # --- Gradio Chat Interface (No Change) ---
109
- title = "Kutti: Your TinyLlama Problem Solver"
110
- description = "Hello! I'm Kutti. How can I help you? Ask me anything from math problems to general questions."
111
 
112
  gr.ChatInterface(
113
  fn=generate_response,
114
- chatbot=gr.Chatbot(
115
- height=500,
116
- type='messages',
117
- value=[{'role': 'assistant', 'content': "Hello! I'm Kutti. How can I help you today?"}]
118
- ),
119
- textbox=gr.Textbox(placeholder="Ask your question or problem here...", scale=7),
120
  title=title,
121
  description=description,
122
- submit_btn="Send",
 
 
123
  theme="soft"
124
  ).queue().launch()
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
 
6
+ # --- Configuration (Verified) ---
7
  BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
8
+ # Ensure this is correct for your model repository
9
  ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
10
 
11
+ # Define the instruction template components
12
+ SYSTEM_INSTRUCTION = "Solve the following math problem:"
13
+ USER_TEMPLATE = "<|user|>\n{}</s>"
14
+ ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"
 
 
 
15
 
16
 
17
  # --- Model Loading Function ---
18
  def load_model():
19
  """Loads the base model and merges the LoRA adapters."""
20
  print("Loading base model...")
21
+ # Use bfloat16 for efficiency on CPU
22
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
 
 
23
  model = AutoModelForCausalLM.from_pretrained(
24
  BASE_MODEL_ID,
25
  torch_dtype=torch.bfloat16,
 
27
  )
28
 
29
  print("Loading and merging PEFT adapters...")
30
+ # Load the trained LoRA adapters
31
  model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID)
32
  model = model.merge_and_unload()
33
  model.eval()
34
 
35
+ # Ensure pad token is set for generation
36
  if tokenizer.pad_token is None:
37
  tokenizer.pad_token = tokenizer.eos_token
38
 
39
  print("Model loaded and merged successfully!")
40
  return tokenizer, model
41
 
42
+ # Load the model outside the prediction function for efficiency
43
  tokenizer, model = load_model()
44
 
45
 
46
+ # --- Prediction Function for gr.ChatInterface ---
47
  def generate_response(message, history):
48
+ """Generates a response using chat history and the fine-tuned model."""
49
 
50
+ # 1. Build the full prompt including System Instruction, History, and current Message
 
51
 
52
+ # Start with the system instruction
53
+ full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"
54
+
55
+ # Append the chat history (if any)
56
+ for user_msg, assistant_msg in history:
57
+ full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
58
+ full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"
59
+
60
+ # Append the current user message and the start of the assistant's turn
61
+ full_prompt += USER_TEMPLATE.format(message) + "\n"
62
+ full_prompt += "<|assistant|>\n"
63
+
64
+ print(f"--- Full Prompt ---\n{full_prompt}")
 
65
 
66
+ # 2. Tokenize the input
67
  inputs = tokenizer(full_prompt, return_tensors="pt")
68
+
69
+ # 3. Generate the response (on CPU)
70
  with torch.no_grad():
71
  output_tokens = model.generate(
72
  **inputs,
73
  max_new_tokens=256,
74
  do_sample=True,
75
+ temperature=0.7,
76
  top_k=50,
77
+ pad_token_id=tokenizer.eos_token_id
 
 
 
78
  )
79
 
80
+ # 4. Decode the output
81
+ generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
 
82
 
83
+ # 5. Extract only the model's new response
84
+ # Find the start of the assistant's turn in the output and everything after it
85
+ response_start = generated_text.rfind('<|assistant|>')
86
+ if response_start != -1:
87
+ # Get the text after <|assistant|> and strip the trailing </s>
88
+ raw_response = generated_text[response_start + len('<|assistant|>'):].strip()
89
+ assistant_response = raw_response.split('</s>')[0].strip()
 
 
 
 
90
  else:
91
+ assistant_response = "Error: Could not parse model output."
 
92
 
93
  return assistant_response
94
 
95
 
96
+ # --- Gradio Chat Interface ---
97
+ title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"
98
+ description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"
99
 
100
  gr.ChatInterface(
101
  fn=generate_response,
102
+ chatbot=gr.Chatbot(height=500), # Makes the chat history window taller
103
+ textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
 
 
 
 
104
  title=title,
105
  description=description,
106
+ submit_btn="Ask Model",
107
+ clear_btn="Start New Chat",
108
+ undo_btn="Undo Last Message",
109
  theme="soft"
110
  ).queue().launch()