Update app.py
Browse files
app.py
CHANGED
|
@@ -1,105 +1,102 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
import
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
#
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
full_input =
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
{
|
| 54 |
-
<|
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
"
|
| 96 |
-
"
|
| 97 |
-
"
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
if __name__ == "__main__":
|
| 105 |
-
demo.launch()
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 7 |
+
from peft import PeftModel
|
| 8 |
+
|
| 9 |
+
# Load model (will load from HF Hub or local)
|
| 10 |
+
print("Loading LTO model...")
|
| 11 |
+
|
| 12 |
+
# Base model
|
| 13 |
+
base_model_name = "unsloth/Llama-3.2-3B-bnb-4bit"
|
| 14 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
| 15 |
+
|
| 16 |
+
# Load with transformers + PEFT
|
| 17 |
+
print("Loading with transformers...")
|
| 18 |
+
|
| 19 |
+
bnb_config = BitsAndBytesConfig(
|
| 20 |
+
load_in_4bit=True,
|
| 21 |
+
bnb_4bit_quant_type="nf4",
|
| 22 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 26 |
+
base_model_name,
|
| 27 |
+
quantization_config=bnb_config,
|
| 28 |
+
device_map="auto",
|
| 29 |
+
)
|
| 30 |
+
model = PeftModel.from_pretrained(model, "./lora_model_lto")
|
| 31 |
+
model.eval()
|
| 32 |
+
|
| 33 |
+
print("Model loaded!")
|
| 34 |
+
|
| 35 |
+
# Chat template
|
| 36 |
+
SYSTEM_PROMPT = """You are LTO, a French member of CS City Discord. You do technical analysis on stocks and crypto (fundas are trash). You're aggressive in banter and use phrases like "on my wife", "kys", "die", "bozo", "dubai scammer", "fr", "ngl", "bcs". Keep it real and match the energy."""
|
| 37 |
+
|
| 38 |
+
def format_prompt(message, history):
|
| 39 |
+
# Build conversation context
|
| 40 |
+
context_parts = []
|
| 41 |
+
for user_msg, bot_msg in history[-3:]: # Last 3 exchanges
|
| 42 |
+
context_parts.append(f"[earlier] User: {user_msg}")
|
| 43 |
+
context_parts.append(f"[earlier] LTO: {bot_msg}")
|
| 44 |
+
|
| 45 |
+
if context_parts:
|
| 46 |
+
full_input = "\n".join(context_parts) + f"\nUser: {message}"
|
| 47 |
+
else:
|
| 48 |
+
full_input = f"User: {message}"
|
| 49 |
+
|
| 50 |
+
prompt = f"""<|system|>
|
| 51 |
+
{SYSTEM_PROMPT}
|
| 52 |
+
<|user|>
|
| 53 |
+
{full_input}
|
| 54 |
+
<|assistant|>
|
| 55 |
+
"""
|
| 56 |
+
return prompt
|
| 57 |
+
|
| 58 |
+
def respond(message, history):
|
| 59 |
+
prompt = format_prompt(message, history)
|
| 60 |
+
|
| 61 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 62 |
+
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
outputs = model.generate(
|
| 65 |
+
**inputs,
|
| 66 |
+
max_new_tokens=150,
|
| 67 |
+
temperature=0.75,
|
| 68 |
+
top_p=0.9,
|
| 69 |
+
do_sample=True,
|
| 70 |
+
repetition_penalty=1.15,
|
| 71 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 75 |
+
|
| 76 |
+
# Extract assistant response
|
| 77 |
+
if "<|assistant|>" in response:
|
| 78 |
+
response = response.split("<|assistant|>")[-1].strip()
|
| 79 |
+
|
| 80 |
+
# Clean up
|
| 81 |
+
response = response.replace("<|system|>", "").replace("<|user|>", "").strip()
|
| 82 |
+
if "\n" in response:
|
| 83 |
+
response = response.split("\n")[0].strip()
|
| 84 |
+
|
| 85 |
+
return response
|
| 86 |
+
|
| 87 |
+
# Create Gradio interface
|
| 88 |
+
demo = gr.ChatInterface(
|
| 89 |
+
respond,
|
| 90 |
+
title="🇫🇷 Chat with LTO",
|
| 91 |
+
description="LTO from CS City Discord. He does TA, hates fundas, and says 'on my wife' a lot. Be ready for aggressive banter!",
|
| 92 |
+
examples=[
|
| 93 |
+
"hey",
|
| 94 |
+
"what do you think of fundas?",
|
| 95 |
+
"cap",
|
| 96 |
+
"you're lying",
|
| 97 |
+
"what crypto should I buy?",
|
| 98 |
+
],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
demo.launch()
|
|
|
|
|
|
|
|
|