simonper committed on
Commit
14d11ec
·
verified ·
1 Parent(s): 91771ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -82
app.py CHANGED
@@ -1,125 +1,123 @@
1
  import gradio as gr
2
- from llama_cpp import Llama
 
 
3
 
4
- # Initialize the model
5
- llm = Llama.from_pretrained(
6
- repo_id="simonper/fine-tuned-gguf-modal1",
7
- filename="Llama-3.2-1B.Q8_0.gguf",
8
- n_ctx=2048,
9
- n_threads=2,
10
- verbose=False
11
- )
12
 
13
# --- 1. LLAMA 3 SPECIFIC FORMATTING ---
def format_llama3_prompt(system_message: str, history: list[dict], user_message: str) -> str:
    """Format a conversation using the official Llama 3 special tokens.

    Args:
        system_message: Persona/system instruction, placed first in the prompt.
        history: Prior turns as dicts with 'role' and 'content' keys
            (Gradio 'messages' format).
        user_message: The current user message, appended last.

    Returns:
        A single prompt string ending with an open assistant header, so the
        model's generation continues exactly where the assistant's reply starts.
    """
    # Collect segments and join once, instead of repeated += concatenation.
    parts = ["<|begin_of_text|>"]

    # System message always leads the conversation.
    parts.append(f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>")

    # Replay prior turns in order, preserving each turn's role.
    for turn in history:
        parts.append(f"<|start_header_id|>{turn['role']}<|end_header_id|>\n\n{turn['content']}<|eot_id|>")

    # Current user message, then an OPEN assistant header (no <|eot_id|>)
    # so the model generates the assistant's reply.
    parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|>")
    # Plain string literal: the original used an f-string with no placeholders.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")

    return "".join(parts)
36
 
37
# --- 2. ENHANCED SYSTEM PROMPTS ---
def get_system_prompt(style_mode):
    """Return a rich persona definition for the selected chat style.

    Unknown or unrecognized styles fall back to the "Normal" persona.
    """
    base = "You are a helpful and intelligent AI assistant."

    # Each entry is the persona-specific suffix appended to the base instruction.
    persona_suffixes = {
        "Normal": " Answer the user's questions clearly and concisely.",
        "Professional": (
            " You are a senior corporate executive. "
            "Your tone is strictly professional, polite, and business-oriented. "
            "Use formal vocabulary, avoid slang, and structure your answers with bullet points where possible."
        ),
        "Shakespeare": (
            " You are William Shakespeare. "
            "You speak only in Early Modern English (using thee, thou, hath, etc.). "
            "Your responses should be poetic, dramatic, and perhaps slightly archaic."
        ),
        "Funny/Ironic": (
            " You are a sarcastic comedian who loves irony. "
            "While you must still answer the user's question, wrap the answer in dry humor, "
            "witty remarks, and self-deprecating jokes. Do not be overly polite."
        ),
    }

    suffix = persona_suffixes.get(style_mode, persona_suffixes["Normal"])
    return base + suffix
 
 
 
 
65
 
 
66
def respond(
    message,
    history: list[dict],
    system_message_dummy,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    style_mode,
):
    """Generate one chat reply with the local GGUF model.

    `system_message_dummy` (the hidden textbox) is ignored; the real system
    prompt is derived from the selected style dropdown.
    """
    # Persona comes from the style dropdown, not the hidden textbox input.
    persona = get_system_prompt(style_mode)

    # Keep only the most recent turns to stay inside the context window.
    recent_history = history[-10:] if len(history) > 10 else history

    # Build the Llama 3 formatted prompt and run a single completion.
    llama_prompt = format_llama3_prompt(persona, recent_history, message)
    completion = llm(
        llama_prompt,
        max_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repeat_penalty=float(repetition_penalty),
        stop=["<|eot_id|>", "<|end_of_text|>"],
        echo=False,
    )

    # llama-cpp returns an OpenAI-style dict; take the first choice's text.
    return completion["choices"][0]["text"].strip()
 
97
 
98
# --- 3. GUI SETUP ---
# Extra controls handed to the chat callback, in the exact order the
# `respond` signature expects them after (message, history).
_extra_controls = [
    gr.Textbox(value="", label="System Prompt (Hidden)", visible=False),
    gr.Slider(minimum=1, maximum=1024, value=512, label="Max New Tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p"),
    gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty"),
    gr.Dropdown(
        choices=["Normal", "Professional", "Shakespeare", "Funny/Ironic"],
        value="Normal",
        label="Choose the Style / Tone",
    ),
]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=_extra_controls,
)

with gr.Blocks() as demo:
    gr.Markdown("# Advanced Chat Bot (Llama 3.2 1B)")
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()
124
 
125
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
+ from threading import Thread
5
 
6
# --- 1. SETUP MODEL & TOKENIZER ---
# User requested the BASE (untrained) checkpoint, not the Instruct variant.
MODEL_ID = "meta-llama/Llama-3.2-1B"

# Default to CPU; upgrade to GPU only when CUDA is actually available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Loading base model on: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        device_map="auto"
    )

    # CRITICAL FIX FOR BASE MODELS:
    # Base checkpoints usually ship without a 'chat_template' in their config,
    # because they are plain completion models rather than chat models.
    # Assign the Llama 3 template manually so apply_chat_template() does not
    # crash later in respond().
    if tokenizer.chat_template is None:
        print("Base model detected: Assigning default Llama 3 chat template...")
        tokenizer.chat_template = (
            "{% set loop_messages = messages %}"
            "{% for message in loop_messages %}"
            "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}"
            "{% if loop.index0 == 0 %}"
            "{% set content = '<|begin_of_text|>' + content %}"
            "{% endif %}"
            "{{ content }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
            "{% endif %}"
        )
    # Base tokenizers often lack a pad token; reuse EOS so generation can pad.
    tokenizer.pad_token_id = tokenizer.eos_token_id

except Exception as e:
    print(f"Error loading model. Ensure you have a valid HF_TOKEN and access to the gated repo. Error: {e}")
    # Bare `raise` re-raises the active exception with its original traceback
    # (`raise e` would add a redundant frame at this line).
    raise
47
 
48
# --- 2. GENERATION FUNCTION ---
def respond(
    message,
    history: list[dict],
    system_message_dummy,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    style_mode,
):
    """Generate one assistant reply with the base Llama 3.2 model.

    `system_message_dummy` (the hidden textbox) is ignored; the system
    prompt is derived from `style_mode`. Base models mostly ignore system
    prompts, but one is included for prompt structure.
    """
    # Map every dropdown choice to a persona. FIX: the previous version
    # silently ignored the "Professional" dropdown option; it now gets a
    # matching prompt. Unknown styles fall back to the generic assistant.
    style_prompts = {
        "Shakespeare": "You are William Shakespeare. Speak in Early Modern English.",
        "Funny/Ironic": "You are a sarcastic comedian.",
        "Professional": "You are a strictly professional, business-oriented assistant.",
    }
    system_prompt = style_prompts.get(style_mode, "You are an AI assistant.")

    # Context window management: keep only the 10 most recent turns.
    if len(history) > 10:
        history = history[-10:]

    # Build the chat transcript in the 'messages' format.
    messages = [{"role": "system", "content": system_prompt}]
    for turn in history:
        messages.append({"role": turn['role'], "content": turn['content']})
    messages.append({"role": "user", "content": message})

    # Apply the (possibly manually assigned) Llama 3 chat template and move
    # the token ids onto the same device as the model.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Stop on either the model's EOS or the Llama 3 end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    # Generate. inference_mode avoids building autograd state, cutting
    # memory use during generation without changing sampled output.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            max_new_tokens=int(max_tokens),
            eos_token_id=terminators,
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            do_sample=True,
        )

    # Strip the prompt tokens; decode only the newly generated tail.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
102
 
103
# --- 3. GUI SETUP ---
# (Kept identical to previous, just updated title)
# Extra controls handed to the chat callback, in the exact order the
# `respond` signature expects them after (message, history).
_controls = [
    gr.Textbox(value="", label="System Prompt (Hidden)", visible=False),
    gr.Slider(minimum=1, maximum=1024, value=512, label="Max New Tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p"),
    gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty"),
    gr.Dropdown(choices=["Normal", "Professional", "Shakespeare", "Funny/Ironic"], value="Normal", label="Style"),
]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=_controls,
)

with gr.Blocks() as demo:
    gr.Markdown("# Chat with Llama 3.2 1B (Base/Untrained)")
    gr.Markdown("> **Warning:** You are running the base model. It will likely hallucinate or autocomplete text rather than chatting normally.")
    chatbot.render()
122
 
123
  if __name__ == "__main__":