import os
# Must be set BEFORE transformers is imported, or the flag has no effect.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
| |
|
| | import gradio as gr |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| | from peft import PeftModel |
| |
|
| | |
# --- Model loading (runs once, at import time) ---
print("Loading LTO model...")

# Base model: unsloth's pre-quantized 4-bit Llama 3.2 3B checkpoint.
base_model_name = "unsloth/Llama-3.2-3B-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

print("Loading with transformers...")

# 4-bit NF4 quantization with fp16 compute to keep VRAM usage low.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    # Let accelerate place layers across available devices automatically.
    device_map="auto",
)
# Attach the fine-tuned LoRA adapter on top of the quantized base.
# NOTE(review): relative path — assumes the script is launched from the
# directory containing ./lora_model_lto; confirm the deployment cwd.
model = PeftModel.from_pretrained(model, "./lora_model_lto")
model.eval()  # inference mode (disables dropout etc.)

print("Model loaded!")
| |
|
| | |
# Persona definition injected as the system turn of every prompt.
SYSTEM_PROMPT = """You are LTO, a French member of CS City Discord. You do technical analysis on stocks and crypto (fundas are trash). You're aggressive in banter and use phrases like "on my wife", "kys", "die", "bozo", "dubai scammer", "fr", "ngl", "bcs". Keep it real and match the energy."""
| |
|
def format_prompt(message, history):
    """Assemble the model prompt from the current message and chat history.

    Args:
        message: The user's latest message.
        history: List of (user_msg, bot_msg) pairs; only the last three
            exchanges are kept as context.

    Returns:
        A prompt string using the ``<|system|>``/``<|user|>``/``<|assistant|>``
        template the model was fine-tuned on.
    """
    # Fold up to three prior exchanges plus the new message into one turn.
    turns = []
    for past_user, past_bot in history[-3:]:
        turns.append(f"[earlier] User: {past_user}")
        turns.append(f"[earlier] LTO: {past_bot}")
    turns.append(f"User: {message}")
    full_input = "\n".join(turns)

    return (
        f"<|system|>\n"
        f"{SYSTEM_PROMPT}\n"
        f"<|user|>\n"
        f"{full_input}\n"
        f"<|assistant|>\n"
    )
| |
|
def respond(message, history):
    """Generate one chat reply for the Gradio ChatInterface.

    Args:
        message: The user's latest message.
        history: List of (user_msg, bot_msg) pairs from the UI.

    Returns:
        A single-line string reply from the model.
    """
    prompt = format_prompt(message, history)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens, used below to isolate the generated tail.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
        )

    # FIX: decode only the newly generated tokens. The previous approach
    # decoded the full sequence (prompt included) and split on the literal
    # "<|assistant|>" marker, which echoes the whole prompt back to the user
    # whenever tokenization round-tripping alters that marker.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Defensive scrub in case the model regurgitates template markers.
    for marker in ("<|assistant|>", "<|system|>", "<|user|>"):
        response = response.replace(marker, "")
    response = response.strip()

    # Keep only the first line — the persona is trained on one-liner chat.
    if "\n" in response:
        response = response.split("\n", 1)[0].strip()

    return response
| |
|
| | |
# Canned starter messages shown under the chat box.
_EXAMPLE_PROMPTS = [
    "hey",
    "what do you think of fundas?",
    "cap",
    "you're lying",
    "what crypto should I buy?",
]

# Gradio chat UI wired to the generation function above.
demo = gr.ChatInterface(
    fn=respond,
    title="🇫🇷 Chat with LTO",
    description="LTO from CS City Discord. He does TA, hates fundas, and says 'on my wife' a lot. Be ready for aggressive banter!",
    examples=_EXAMPLE_PROMPTS,
)


if __name__ == "__main__":
    demo.launch()
| |
|