Ryan Robson committed
Commit 8bd0b76 · 1 Parent(s): 4972096

Improve inference quality


- Add model.eval() for proper evaluation mode
- Simplify prompt format (remove complex conversation history)
- Add min_new_tokens=50 to force meaningful responses
- Add repetition_penalty=1.1 to reduce repetition
- Add top_k sampling for better quality
- Clean up response artifacts
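Taken together, the bullets above amount to the decoding setup sketched below. This is a minimal standalone sketch, assuming a Mistral-style instruct model plus a LoRA adapter loaded with transformers and peft; the BASE_MODEL and ADAPTER_MODEL values and the sample question are illustrative placeholders, not this repo's actual configuration:

    # Standalone sketch of the new inference path (placeholder model IDs).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder
    ADAPTER_MODEL = "someuser/teks-lora-adapter"       # placeholder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
    model = PeftModel.from_pretrained(model, ADAPTER_MODEL).to(device)
    model.eval()  # evaluation mode: disables dropout during inference

    prompt = ("[INST] You are a Texas TEKS educational expert. "
              "Answer this question clearly and helpfully:\n\nWhat is TEKS? [/INST]")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            min_new_tokens=50,       # forces a substantive answer
            temperature=0.8,
            top_p=0.95,
            top_k=50,                # sample only from the 50 most likely tokens
            do_sample=True,
            repetition_penalty=1.1,  # mildly discourages repeated phrases
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(response.split("[/INST]")[-1].strip())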

Files changed (1):
  1. app.py +14 -13
app.py CHANGED
@@ -27,6 +27,7 @@ model = AutoModelForCausalLM.from_pretrained(
 print(f"🔧 Loading LoRA adapter: {ADAPTER_MODEL}...")
 model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
 model = model.to(device)
+model.eval()  # Set to evaluation mode
 
 print("✅ Model loaded successfully!")
 
@@ -42,27 +43,23 @@ def chat(message, history):
         Generated response string
     """
 
-    # Build conversation history in Mistral format
-    prompt = ""
-    for user_msg, bot_msg in history:
-        prompt += f"[INST] {user_msg} [/INST] {bot_msg}</s> "
-
-    # Add current message with system instruction
-    system_message = "You are an expert educational AI assistant specializing in Texas Essential Knowledge and Skills (TEKS) standards. Provide accurate, detailed, and pedagogically sound information to help teachers and students."
-
-    prompt += f"[INST] {system_message}\n\n{message} [/INST]"
+    # Simplified prompt - just the current message
+    prompt = f"[INST] You are a Texas TEKS educational expert. Answer this question clearly and helpfully:\n\n{message} [/INST]"
 
     # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
-    # Generate response
+    # Generate response with better parameters
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=512,
-            temperature=0.7,
-            top_p=0.9,
+            max_new_tokens=300,
+            min_new_tokens=50,
+            temperature=0.8,
+            top_p=0.95,
+            top_k=50,
             do_sample=True,
+            repetition_penalty=1.1,
             pad_token_id=tokenizer.eos_token_id,
             eos_token_id=tokenizer.eos_token_id,
         )
 
@@ -74,6 +71,10 @@ def chat(message, history):
     if "[/INST]" in response:
         response = response.split("[/INST]")[-1].strip()
 
+    # Clean up any remaining artifacts
+    if response.startswith(message):
+        response = response[len(message):].strip()
+
     return response
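One note on the cleanup hunk: splitting the decoded string on "[/INST]" and trimming an echoed question works, but a common alternative is to decode only the newly generated tokens, which avoids string surgery entirely. A sketch, reusing inputs and outputs from the generate() call above (this is not what the commit does):

    # Hypothetical alternative to the string-based cleanup:
    # skip the prompt tokens and decode only the new ones.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()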