import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "models/Llama-3.2-1B-Instruct"

# The original file left these as None placeholders; this is a minimal,
# assumed loading step for a local Hugging Face checkpoint. Adjust dtype
# and device placement to your hardware.
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
lm.to("cuda" if torch.cuda.is_available() else "cpu")
lm.eval()

def chat_current(system_prompt: str, user_prompt: str) -> str:
    """
    Current implementation (same as server.py) - will show warnings.
    """
    print("🔴 Running CURRENT implementation (with warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # apply_chat_template returns a bare tensor here, so generate() never
    # receives an attention mask and has to guess a pad token, which is
    # what triggers the runtime warnings.
    input_ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    # Decode only the newly generated tokens, skipping the prompt.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()

def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    """
    Fixed implementation - passes an explicit attention mask and pad token.
    """
    print("🟢 Running FIXED implementation (no warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # return_dict=True makes apply_chat_template return a dict holding both
    # input_ids and the matching attention_mask instead of a bare tensor.
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Llama has no dedicated pad token; reusing eos silences the
            # "Setting pad_token_id to eos_token_id" warning.
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    # Decode only the newly generated tokens, skipping the prompt.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
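
# Example call of the fixed path (hypothetical prompts; uncomment once the
# model files referenced above are in place):
# reply = chat_fixed(
#     "You are a helpful assistant.",
#     "Give one tip for managing stress at work.",
# )
# print(reply)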

def compare_generations():
    """Compare both implementations on the same prompts."""
    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    user_prompt = "Create a report on anxiety at work. How do I manage time and stress effectively?"

    print("=" * 60)
    print("COMPARING GENERATION METHODS")
    print("=" * 60)
    print(f"System: {system_prompt}")
    print(f"User: {user_prompt}")
    print("=" * 60)

    print("\n" + "=" * 60)
    current_output = chat_current(system_prompt, user_prompt)
    print(f"CURRENT OUTPUT:\n{current_output}")

    print("\n" + "=" * 60)
    fixed_output = chat_fixed(system_prompt, user_prompt)
    print(f"FIXED OUTPUT:\n{fixed_output}")

    print("\n" + "=" * 60)
    print("COMPARISON:")
    # With do_sample=True the two runs sample independently, so exact
    # equality is unlikely even when both implementations are correct.
    print(f"Outputs are identical: {current_output == fixed_output}")
    print(f"Current length: {len(current_output)} chars")
    print(f"Fixed length: {len(fixed_output)} chars")

def filter_by_word_count(data, max_words=3):
    """Return only phrases with word count <= max_words."""
    return {k: v for k, v in data.items() if len(v.split()) <= max_words}


def filter_by_keyword(data, keyword):
    """Return phrases containing a specific keyword (case-insensitive)."""
    return {k: v for k, v in data.items() if keyword.lower() in v.lower()}
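
# Example usage of the filters (hypothetical data; the original source of
# `data` is not shown in this file):
# data = {"q1": "time management", "q2": "deep breathing exercises at work"}
# filter_by_word_count(data, max_words=2)  # -> {"q1": "time management"}
# filter_by_keyword(data, "work")          # -> {"q2": "deep breathing exercises at work"}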
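
if __name__ == "__main__":
    # Assumed entry point; no main guard survived in the original file.
    compare_generations()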