Vivek16 committed on
Commit
b2c590c
·
verified ·
1 Parent(s): f94bcb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -36
app.py CHANGED
@@ -16,13 +16,14 @@ SYSTEM_INSTRUCTION = (
16
  )
17
 
18
 
19
- # --- Model Loading Function (No change from last successful load) ---
20
  def load_model():
21
  """Loads the base model and merges the LoRA adapters."""
22
  print("Loading base model...")
23
  # Load the tokenizer, which includes the necessary chat template
24
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
25
 
 
26
  model = AutoModelForCausalLM.from_pretrained(
27
  BASE_MODEL_ID,
28
  torch_dtype=torch.bfloat16,
@@ -43,7 +44,7 @@ def load_model():
43
  tokenizer, model = load_model()
44
 
45
 
46
- # --- Prediction Function (KEY MODIFICATION: Using tokenizer.apply_chat_template) ---
47
  def generate_response(message, history):
48
  """Generates a response using the official chat template and generation constraints."""
49
 
@@ -55,25 +56,26 @@ def generate_response(message, history):
55
 
56
  # Add historical messages
57
  for message_dict in history:
58
- # Gradio history items are dicts with 'role' and 'content' keys
59
  messages.append({"role": message_dict['role'], "content": message_dict['content']})
60
 
61
  # Add the current user message
62
  messages.append({"role": "user", "content": message})
63
 
64
- # 2. Apply the model's official chat template to the entire conversation
 
 
65
  full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
66
 
67
  # 3. Tokenize the input
68
  inputs = tokenizer(full_prompt, return_tensors="pt")
69
 
70
- # 4. Generate the response with anti-repetition constraints
71
  with torch.no_grad():
72
  output_tokens = model.generate(
73
  **inputs,
74
  max_new_tokens=256,
75
  do_sample=True,
76
- temperature=0.7,
77
  top_k=50,
78
  pad_token_id=tokenizer.eos_token_id,
79
  # Constraints to prevent repetitive filler:
@@ -81,37 +83,24 @@ def generate_response(message, history):
81
  repetition_penalty=1.5
82
  )
83
 
84
- # 5. Decode and clean the output
85
- # Decode the entire output sequence
86
- generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
87
 
88
- # The output contains the *entire* prompt + the new response.
89
- # We must strip the prompt and the user's final message to get the clean response.
90
-
91
- # The last user message is the end of the prompt we want to remove
92
- last_user_message = messages[-1]["content"]
93
-
94
- # Find the beginning of the model's answer, which comes after the last user message
95
- # We use the full user message content for a reliable split point.
96
-
97
- # In the full_prompt format, the model is expected to start immediately after the last user turn.
98
- # We use a simple method: find the last user message and take everything after it.
99
-
100
- try:
101
- # Find where the final user message ends in the generated text (plus a little padding for the template)
102
- split_point = generated_text.rfind(last_user_message)
103
- if split_point != -1:
104
- # Everything after the split point is the generated response
105
- assistant_response = generated_text[split_point + len(last_user_message):].strip()
106
- else:
107
- # Fallback extraction (may be less reliable)
108
- assistant_response = generated_text.strip()
109
- except Exception:
110
- # General safety fallback
111
- assistant_response = generated_text.strip()
112
-
113
- # Final cleanup to ensure no special tokens or remnants are left if skip_special_tokens=False
114
- assistant_response = assistant_response.split('</s>')[0].split('<|user|>')[0].strip()
115
 
116
  return assistant_response
117
 
 
16
  )
17
 
18
 
19
+ # --- Model Loading Function ---
20
  def load_model():
21
  """Loads the base model and merges the LoRA adapters."""
22
  print("Loading base model...")
23
  # Load the tokenizer, which includes the necessary chat template
24
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
25
 
26
+ # Force loading to CPU as per your setup
27
  model = AutoModelForCausalLM.from_pretrained(
28
  BASE_MODEL_ID,
29
  torch_dtype=torch.bfloat16,
 
44
  tokenizer, model = load_model()
45
 
46
 
47
+ # --- Prediction Function (Modified for MAX stability and lower temperature) ---
48
  def generate_response(message, history):
49
  """Generates a response using the official chat template and generation constraints."""
50
 
 
56
 
57
  # Add historical messages
58
  for message_dict in history:
 
59
  messages.append({"role": message_dict['role'], "content": message_dict['content']})
60
 
61
  # Add the current user message
62
  messages.append({"role": "user", "content": message})
63
 
64
+ # 2. Apply the model's official chat template
65
+ # NOTE: The "TinyLlama/TinyLlama-1.1B-Chat-v1.0" model expects a template like:
66
+ # <|system|>\nSYSTEM_INSTRUCTION</s>\n<|user|>\nMESSAGE</s>\n<|assistant|>\n
67
  full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
68
 
69
  # 3. Tokenize the input
70
  inputs = tokenizer(full_prompt, return_tensors="pt")
71
 
72
+ # 4. Generate the response with anti-repetition constraints and LOWER TEMPERATURE
73
  with torch.no_grad():
74
  output_tokens = model.generate(
75
  **inputs,
76
  max_new_tokens=256,
77
  do_sample=True,
78
+ temperature=0.6, # Slightly lower temp for less gibberish
79
  top_k=50,
80
  pad_token_id=tokenizer.eos_token_id,
81
  # Constraints to prevent repetitive filler:
 
83
  repetition_penalty=1.5
84
  )
85
 
86
+ # 5. Decode with skip_special_tokens=False so the template tags survive for locating the new response
87
+ # We still need to find where the *new* response begins.
88
+ generated_text_with_prompt = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
89
 
90
+ # Extract only the model's new response by finding the last <|assistant|> tag
91
+ # The last tag marks the beginning of the new response.
92
+ assistant_prefix_tag = "<|assistant|>"
93
+ response_start_index = generated_text_with_prompt.rfind(assistant_prefix_tag)
94
+
95
+ if response_start_index != -1:
96
+ # Get everything after the last <|assistant|> tag
97
+ raw_response = generated_text_with_prompt[response_start_index + len(assistant_prefix_tag):].strip()
98
+
99
+ # Clean up any trailing end-of-sequence tags (</s>) or user tags (<|user|>)
100
+ assistant_response = raw_response.split("</s>")[0].split("<|user|>")[0].strip()
101
+ else:
102
+ # Fallback to the decoded text if the tag is not found (and hope for the best)
103
+ assistant_response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  return assistant_response
106