cochi1706 commited on
Commit
76de232
·
1 Parent(s): 376a746

Streamline model loading and response generation in chatbot application by utilizing a text generation pipeline. Removed legacy loading methods and improved response handling for enhanced performance and clarity.

Browse files
Files changed (1) hide show
  1. app.py +33 -126
app.py CHANGED
@@ -1,87 +1,30 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from peft import PeftModel
5
 
6
 
7
  # Load tokenizer và model
8
  print("Đang tải model...")
9
- import os
 
 
10
 
11
- base_model_name = "Qwen/Qwen3-0.6B"
12
- adapter_path_local = "./qwen3-finetuned"
13
- model_loaded = False
14
-
15
- # Ưu tiên 1: Thử load từ local path (nếu có)
16
- if os.path.exists(adapter_path_local) and os.path.exists(os.path.join(adapter_path_local, "adapter_config.json")):
17
- try:
18
- print(f"Đang load từ local path: {adapter_path_local}")
19
- base_model = AutoModelForCausalLM.from_pretrained(
20
- base_model_name,
21
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
22
- device_map="auto" if torch.cuda.is_available() else None,
23
- )
24
- model = PeftModel.from_pretrained(base_model, adapter_path_local)
25
- tokenizer = AutoTokenizer.from_pretrained(adapter_path_local, local_files_only=True)
26
- model_loaded = True
27
- print("✓ Đã load model từ local path")
28
- except Exception as e:
29
- print(f"✗ Không thể load từ local: {e}")
30
-
31
- # Ưu tiên 2: Thử load từ HuggingFace như full model
32
- if not model_loaded:
33
- try:
34
- model_name = "cochi1706/decoder/qwen3-finetuned"
35
- print(f"Đang thử load full model từ: {model_name}")
36
- tokenizer = AutoTokenizer.from_pretrained(model_name)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_name,
39
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
40
- device_map="auto" if torch.cuda.is_available() else None,
41
- )
42
- model_loaded = True
43
- print("✓ Đã load full model từ HuggingFace")
44
- except Exception as e:
45
- print(f"✗ Không thể load full model: {e}")
46
-
47
- # Ưu tiên 3: Load như PEFT adapter từ HuggingFace
48
- if not model_loaded:
49
- try:
50
- print("Đang load base model và PEFT adapter từ HuggingFace...")
51
- base_model = AutoModelForCausalLM.from_pretrained(
52
- base_model_name,
53
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
54
- device_map="auto" if torch.cuda.is_available() else None,
55
- )
56
- # Thử các adapter paths khác nhau
57
- adapter_paths = [
58
- "cochi1706/coding-assistant",
59
- "cochi1706/decoder/qwen3-finetuned",
60
- ]
61
- for adapter_path in adapter_paths:
62
- try:
63
- print(f" Thử adapter path: {adapter_path}")
64
- model = PeftModel.from_pretrained(base_model, adapter_path)
65
- tokenizer = AutoTokenizer.from_pretrained(adapter_path)
66
- model_loaded = True
67
- print(f"✓ Đã load PEFT adapter từ: {adapter_path}")
68
- break
69
- except Exception as e:
70
- print(f" ✗ Không thể load từ {adapter_path}: {e}")
71
- continue
72
- except Exception as e:
73
- print(f"✗ Không thể load base model: {e}")
74
-
75
- if not model_loaded:
76
- raise RuntimeError("Không thể load model từ bất kỳ nguồn nào. Vui lòng kiểm tra lại model path.")
77
-
78
- # Xác định device
79
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
 
81
  # Set padding token nếu chưa có
82
  if tokenizer.pad_token is None:
83
  tokenizer.pad_token = tokenizer.eos_token
84
 
 
 
 
 
 
 
 
 
 
85
  model.eval()
86
  print(f"Model đã sẵn sàng! Device: {device}")
87
 
@@ -95,7 +38,7 @@ def respond(
95
  top_p,
96
  ):
97
  """
98
- Tạo phản hồi từ model coding assistant
99
  """
100
  # Chuẩn bị prompt với chat template
101
  messages = [{"role": "system", "content": system_message}]
@@ -109,61 +52,25 @@ def respond(
109
  add_generation_prompt=True
110
  )
111
 
112
- # Tokenize
113
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
114
 
115
- # Di chuyển inputs đến device của model
116
- # Nếu model đã có device_map, lấy device từ model parameters
117
- if hasattr(model, 'hf_device_map') and model.hf_device_map:
118
- # Model đã được phân bổ trên nhiều device, sử dụng device của layer ��ầu tiên
119
- first_param_device = next(model.parameters()).device
120
- inputs = {k: v.to(first_param_device) for k, v in inputs.items()}
121
- else:
122
- # Model trên một device duy nhất
123
- inputs = {k: v.to(device) for k, v in inputs.items()}
124
 
125
- # Generate với streaming token-by-token
126
- input_length = inputs["input_ids"].shape[1]
127
- response = ""
128
 
129
- with torch.no_grad():
130
- # Khởi tạo với input_ids
131
- generated_ids = inputs["input_ids"].clone()
132
-
133
- for _ in range(max_tokens):
134
- # Forward pass
135
- outputs = model(generated_ids)
136
- logits = outputs.logits[:, -1, :]
137
-
138
- # Apply temperature và top_p
139
- if temperature != 1.0:
140
- logits = logits / temperature
141
-
142
- # Top-p sampling
143
- if top_p < 1.0:
144
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
145
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
146
- sorted_indices_to_remove = cumulative_probs > top_p
147
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
148
- sorted_indices_to_remove[..., 0] = 0
149
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
150
- logits[indices_to_remove] = float('-inf')
151
-
152
- # Sample next token
153
- probs = torch.softmax(logits, dim=-1)
154
- next_token = torch.multinomial(probs, num_samples=1)
155
-
156
- # Kiểm tra EOS token
157
- if next_token.item() == tokenizer.eos_token_id:
158
- break
159
-
160
- # Thêm token vào generated_ids
161
- generated_ids = torch.cat([generated_ids, next_token], dim=1)
162
-
163
- # Decode token mới và stream
164
- new_text = tokenizer.decode([next_token.item()], skip_special_tokens=True)
165
- response += new_text
166
- yield response
167
 
168
 
169
  """
@@ -180,7 +87,7 @@ chatbot = gr.ChatInterface(
180
  label="System message",
181
  lines=3,
182
  )
183
- ],
184
  )
185
 
186
  demo = chatbot
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
4
 
5
 
6
  # Load tokenizer và model
7
  print("Đang tải model...")
8
+ model_name = "cochi1706/codingassistant"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(model_name)
11
 
12
+ # Xác định device cho pipeline (0 cho cuda, -1 cho cpu)
13
+ device = 0 if torch.cuda.is_available() else -1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Set padding token nếu chưa có
16
  if tokenizer.pad_token is None:
17
  tokenizer.pad_token = tokenizer.eos_token
18
 
19
+ # Tạo pipeline để sinh text
20
+ text_generator = pipeline(
21
+ "text-generation",
22
+ model=model,
23
+ tokenizer=tokenizer,
24
+ device=device,
25
+ do_sample=True,
26
+ )
27
+
28
  model.eval()
29
  print(f"Model đã sẵn sàng! Device: {device}")
30
 
 
38
  top_p,
39
  ):
40
  """
41
+ Tạo phản hồi từ model coding assistant sử dụng pipeline
42
  """
43
  # Chuẩn bị prompt với chat template
44
  messages = [{"role": "system", "content": system_message}]
 
52
  add_generation_prompt=True
53
  )
54
 
55
+ # Sử dụng pipeline để generate text
56
+ generated = text_generator(
57
+ prompt,
58
+ max_length=len(tokenizer.encode(prompt)) + max_tokens,
59
+ max_new_tokens=max_tokens,
60
+ num_return_sequences=1,
61
+ temperature=temperature,
62
+ top_p=top_p,
63
+ do_sample=True,
64
+ )
65
 
66
+ # Lấy câu trả lời từ kết quả
67
+ câu_trả_lời = generated[0]['generated_text']
 
 
 
 
 
 
 
68
 
69
+ # Loại bỏ prompt ban đầu để chỉ lấy phần response
70
+ if prompt in câu_trả_lời:
71
+ câu_trả_lời = câu_trả_lời.replace(prompt, "").strip()
72
 
73
+ return câu_trả_lời
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  """
 
87
  label="System message",
88
  lines=3,
89
  )
90
+ ]
91
  )
92
 
93
  demo = chatbot