abigubi committed on
Commit
0c528f4
·
verified ·
1 Parent(s): 9ab8bae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -105
app.py CHANGED
@@ -1,105 +1,102 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from peft import PeftModel
5
- import os
6
-
7
- # Load model (will load from HF Hub or local)
8
- print("Loading LTO model...")
9
-
10
- # Base model
11
- base_model_name = "unsloth/Llama-3.2-3B-bnb-4bit"
12
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
13
-
14
- # Try to load with LoRA adapter
15
- try:
16
- from unsloth import FastLanguageModel
17
- model, tokenizer = FastLanguageModel.from_pretrained(
18
- model_name="./lora_model_lto", # Local path in Space
19
- max_seq_length=512,
20
- dtype=None,
21
- load_in_4bit=True,
22
- )
23
- FastLanguageModel.for_inference(model)
24
- print("Loaded with Unsloth!")
25
- except:
26
- print("Loading with transformers...")
27
- model = AutoModelForCausalLM.from_pretrained(
28
- base_model_name,
29
- torch_dtype=torch.float16,
30
- device_map="auto",
31
- load_in_4bit=True,
32
- )
33
- model = PeftModel.from_pretrained(model, "./lora_model_lto")
34
-
35
- print("Model loaded!")
36
-
37
- # Chat template
38
- SYSTEM_PROMPT = """You are LTO, a French member of CS City Discord. You do technical analysis on stocks and crypto (fundas are trash). You're aggressive in banter and use phrases like "on my wife", "kys", "die", "bozo", "dubai scammer", "fr", "ngl", "bcs". Keep it real and match the energy."""
39
-
40
- def format_prompt(message, history):
41
- # Build conversation context
42
- context_parts = []
43
- for user_msg, bot_msg in history[-3:]: # Last 3 exchanges
44
- context_parts.append(f"[earlier] User: {user_msg}")
45
- context_parts.append(f"[earlier] LTO: {bot_msg}")
46
-
47
- if context_parts:
48
- full_input = "\n".join(context_parts) + f"\nUser: {message}"
49
- else:
50
- full_input = f"User: {message}"
51
-
52
- prompt = f"""<|system|>
53
- {SYSTEM_PROMPT}
54
- <|user|>
55
- {full_input}
56
- <|assistant|>
57
- """
58
- return prompt
59
-
60
- def respond(message, history):
61
- prompt = format_prompt(message, history)
62
-
63
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
64
-
65
- with torch.no_grad():
66
- outputs = model.generate(
67
- **inputs,
68
- max_new_tokens=150,
69
- temperature=0.75,
70
- top_p=0.9,
71
- do_sample=True,
72
- repetition_penalty=1.15,
73
- pad_token_id=tokenizer.eos_token_id,
74
- )
75
-
76
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
77
-
78
- # Extract assistant response
79
- if "<|assistant|>" in response:
80
- response = response.split("<|assistant|>")[-1].strip()
81
-
82
- # Clean up
83
- response = response.replace("<|system|>", "").replace("<|user|>", "").strip()
84
- if "\n" in response:
85
- response = response.split("\n")[0].strip()
86
-
87
- return response
88
-
89
- # Create Gradio interface
90
- demo = gr.ChatInterface(
91
- respond,
92
- title="🇫🇷 Chat with LTO",
93
- description="LTO from CS City Discord. He does TA, hates fundas, and says 'on my wife' a lot. Be ready for aggressive banter!",
94
- examples=[
95
- "hey",
96
- "what do you think of fundas?",
97
- "cap",
98
- "you're lying",
99
- "what crypto should I buy?",
100
- ],
101
- theme=gr.themes.Soft(),
102
- )
103
-
104
- if __name__ == "__main__":
105
- demo.launch()
 
1
# NOTE: os.environ must be set BEFORE transformers is imported — the library
# reads this flag at import time. Keep these two lines first in the file.
import os
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Load model (will load from HF Hub or local)
print("Loading LTO model...")

# Base model: a pre-quantized 4-bit Llama-3.2-3B checkpoint. The tokenizer is
# loaded from the same repo so vocab/special tokens match the fine-tune.
base_model_name = "unsloth/Llama-3.2-3B-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Load with transformers + PEFT (no unsloth dependency at inference time).
print("Loading with transformers...")

# 4-bit NF4 quantization config; fp16 compute keeps memory low on Space GPUs.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
# Attach the LoRA adapter shipped alongside the app in the Space repo.
model = PeftModel.from_pretrained(model, "./lora_model_lto")
model.eval()

print("Model loaded!")
34
+
35
# Chat template
SYSTEM_PROMPT = """You are LTO, a French member of CS City Discord. You do technical analysis on stocks and crypto (fundas are trash). You're aggressive in banter and use phrases like "on my wife", "kys", "die", "bozo", "dubai scammer", "fr", "ngl", "bcs". Keep it real and match the energy."""

def format_prompt(message, history):
    """Build the raw chat-template prompt sent to the model.

    Prepends up to the last 3 (user, bot) exchanges as "[earlier]" context
    lines, then the current user message, wrapped in the <|system|>/<|user|>/
    <|assistant|> markers the fine-tune was trained on.

    NOTE(review): assumes `history` is Gradio's tuple format
    ``[(user_msg, bot_msg), ...]`` — verify against the installed gradio
    version (newer ChatInterface defaults may use message dicts).
    """
    lines = []
    for past_user, past_bot in history[-3:]:  # keep only the last 3 exchanges
        lines.append(f"[earlier] User: {past_user}")
        lines.append(f"[earlier] LTO: {past_bot}")
    lines.append(f"User: {message}")
    full_input = "\n".join(lines)

    return f"<|system|>\n{SYSTEM_PROMPT}\n<|user|>\n{full_input}\n<|assistant|>\n"
57
+
58
def respond(message, history):
    """Generate one LTO reply for the Gradio ChatInterface.

    Encodes the templated prompt, samples up to 150 new tokens, then trims
    the decoded text down to just the assistant turn (first line only, role
    markers stripped).
    """
    encoded = tokenizer(format_prompt(message, history), return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,  # Llama has no pad token
        )

    text = tokenizer.decode(generated[0], skip_special_tokens=True)

    # The decode includes the prompt; keep only what follows the last
    # assistant marker (our custom markers survive skip_special_tokens).
    marker = "<|assistant|>"
    if marker in text:
        text = text.rsplit(marker, 1)[-1].strip()

    # Drop any stray role tags the model echoed back, then keep just the
    # first line so the bot answers in a single Discord-style message.
    for tag in ("<|system|>", "<|user|>"):
        text = text.replace(tag, "")
    text = text.strip()
    return text.split("\n", 1)[0].strip()
86
+
87
# Gradio chat UI wired to the generation function above.
_EXAMPLE_PROMPTS = [
    "hey",
    "what do you think of fundas?",
    "cap",
    "you're lying",
    "what crypto should I buy?",
]

demo = gr.ChatInterface(
    fn=respond,
    title="🇫🇷 Chat with LTO",
    description="LTO from CS City Discord. He does TA, hates fundas, and says 'on my wife' a lot. Be ready for aggressive banter!",
    examples=_EXAMPLE_PROMPTS,
)

if __name__ == "__main__":
    demo.launch()