import gradio as gr
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel


# Model configuration
MODEL_NAME = "jmcinern/qwen3-8B-cpt-sft-awq"
DPO_ADAPTER = "jmcinern/qomhra-8B-awq-dpo-beta-0.5-checkpoint-checkpoint-100"
THINK_TAG_PATTERN = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
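# Example: the pattern turns "<think>internal reasoning</think>\nAnswer" into "Answer".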


class ChatBot:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.loading = True
        self.load_model()

    def load_model(self):
        """Load model and tokenizer sequentially"""
        try:
            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME, trust_remote_code=True
            )
            print(f"eos_token_id: {self.tokenizer.eos_token_id}")
            print("Tokenizer loaded!")

            print("Loading model...")
            base_model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True,
                device_map="auto",
                torch_dtype="auto",
                low_cpu_mem_usage=True,
            )
            self.model = base_model
            # Uncomment to apply the DPO adapter; I suspect it is causing problems.
            # self.model = PeftModel.from_pretrained(
            #     base_model,
            #     DPO_ADAPTER,
            # )
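            # Alternative sketch (an assumption, not the author's tested path):
            # merging the adapter into the base weights avoids keeping a PEFT
            # wrapper at inference time:
            # self.model = PeftModel.from_pretrained(
            #     base_model, DPO_ADAPTER
            # ).merge_and_unload()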
            print("Model loaded!")
        except Exception as e:
            print(f"Error loading model: {e}")
        finally:
            self.loading = False

    def chat(self, message, history):
        if self.loading:
            return history + [(message, "Model is loading, please wait...")]

        if not self.model:
            return history + [(message, "Model failed to load")]

        # Build messages
        messages = []
        for user_msg, bot_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        # Apply chat template and strip thinking tags
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        prompt = THINK_TAG_PATTERN.sub("", prompt)
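        # With enable_thinking=False, Qwen3-style templates emit an empty
        # <think>...</think> block after the assistant header; the sub() above
        # strips it (and any residual think blocks). The rendered prompt then
        # looks roughly like (exact form comes from the tokenizer's template):
        #   <|im_start|>user
        #   {message}<|im_end|>
        #   <|im_start|>assistant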

        print("----------PROMPT--------------")
        print(prompt)
        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        # Use the end-of-turn token seen in training (per the Qwen chat template)
        stop_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
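        # For stock Qwen3 checkpoints <|im_end|> should already equal
        # tokenizer.eos_token_id (printed during load); resolving it explicitly
        # guards against fine-tunes that changed the EOS config.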
        
        # Generate response
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=0.7,
                do_sample=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=stop_id
            )
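        # Streaming variant (a sketch under assumptions, not wired into this UI):
        # from threading import Thread
        # from transformers import TextIteratorStreamer
        # streamer = TextIteratorStreamer(
        #     self.tokenizer, skip_prompt=True, skip_special_tokens=True
        # )
        # Thread(target=self.model.generate,
        #        kwargs={**inputs, "streamer": streamer,
        #                "max_new_tokens": 2048}).start()
        # for chunk in streamer:
        #     ...  # yield partial text to gr.Chatbot as it arrives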

        # Decode and clean the response, printing debug output at each step
        # Inspect the generated token IDs (prompt tokens sliced off)
        gen_ids = outputs.sequences[0][len(inputs.input_ids[0]):]
        print("\n--- GENERATED TOKEN IDS ---\n", gen_ids.tolist())
        
        # Decode, skipping special tokens
        raw_output = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        print("\n--- RAW DECODED OUTPUT ---\n", repr(raw_output))
        
        # Show first generated token decoded individually
        if len(gen_ids) > 0:
            first_token = self.tokenizer.decode([gen_ids[0]])
            print(f"\n--- FIRST TOKEN --- '{first_token}' ---")
        
        # Clean as usual
        response = THINK_TAG_PATTERN.sub("", raw_output).strip()
        print("\n--- CLEANED RESPONSE ---\n", repr(response))
        return history + [(message, response)]


# Initialize chatbot
bot = ChatBot()

# Create interface
with gr.Blocks() as demo:
    # Title (Irish): "Bilingual Irish–English LLM developed by Abair.ie"
    gr.HTML('<h1 style="margin:0;">LLM dátheangach Gaeilge–Béarla forbartha ag Abair.ie</h1>')
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(placeholder="Type your message...", show_label=False)

    msg.submit(bot.chat, [msg, chatbot], [chatbot]).then(lambda: "", outputs=msg)

if __name__ == "__main__":
    demo.launch()
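    # For a temporary public URL while testing, Gradio also supports:
    # demo.launch(share=True)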