Doanlol commited on
Commit
b4b9b88
·
verified ·
1 Parent(s): 92b919d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -25
app.py CHANGED
@@ -1,56 +1,65 @@
1
- import torch
2
  import gradio as gr
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
 
6
- BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
7
- ADAPTER_ID = "Doanlol/qwen25-vietnamese-van-lora"
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
 
 
10
 
11
  base_model = AutoModelForCausalLM.from_pretrained(
12
  BASE_MODEL,
13
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
14
- device_map="auto",
15
  trust_remote_code=True,
16
  )
17
- model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
18
- model.eval()
19
 
20
- SYSTEM_PROMPT = "Bạn là trợ lý học văn tiếng Việt, trả lời rõ ràng, đúng trọng tâm, không bịa thông tin."
 
21
 
22
- def chat_fn(message, history):
23
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
24
- for user_msg, bot_msg in history:
25
- if user_msg:
26
- messages.append({"role": "user", "content": user_msg})
27
- if bot_msg:
28
- messages.append({"role": "assistant", "content": bot_msg})
29
- messages.append({"role": "user", "content": message})
30
 
 
 
 
 
 
31
  text = tokenizer.apply_chat_template(
32
- messages, tokenize=False, add_generation_prompt=True
 
 
33
  )
34
  inputs = tokenizer(text, return_tensors="pt").to(model.device)
35
 
36
  with torch.no_grad():
37
  outputs = model.generate(
38
  **inputs,
39
- max_new_tokens=300,
 
 
40
  do_sample=True,
41
- temperature=0.7,
42
- top_p=0.9,
43
  repetition_penalty=1.05,
 
44
  )
45
 
46
- full = tokenizer.decode(outputs[0], skip_special_tokens=True)
47
- answer = full[len(text):].strip() if full.startswith(text) else full
48
  return answer
49
 
50
- demo = gr.ChatInterface(
51
- fn=chat_fn,
52
- title="Qwen2.5 Vietnamese Văn học Assistant",
53
- description="Fine-tuned LoRA model by Doanlol",
 
 
 
 
 
 
 
54
  )
55
 
56
  if __name__ == "__main__":
 
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
 
6
+ BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
7
+ LORA_REPO = "Doanlol/qwen25-3b-van-lora"
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
10
+ if tokenizer.pad_token is None:
11
+ tokenizer.pad_token = tokenizer.eos_token
12
 
13
  base_model = AutoModelForCausalLM.from_pretrained(
14
  BASE_MODEL,
15
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
16
+ device_map="auto" if torch.cuda.is_available() else None,
17
  trust_remote_code=True,
18
  )
 
 
19
 
20
+ model = PeftModel.from_pretrained(base_model, LORA_REPO)
21
+ model.eval()
22
 
23
+ SYSTEM_PROMPT = "Bạn là trợ lý viết văn tiếng Việt, lập luận rõ ràng, cảm xúc, đúng trọng tâm đề."
 
 
 
 
 
 
 
24
 
25
+ def generate_essay(prompt, max_new_tokens, temperature, top_p):
26
+ messages = [
27
+ {"role": "system", "content": SYSTEM_PROMPT},
28
+ {"role": "user", "content": prompt},
29
+ ]
30
  text = tokenizer.apply_chat_template(
31
+ messages,
32
+ tokenize=False,
33
+ add_generation_prompt=True
34
  )
35
  inputs = tokenizer(text, return_tensors="pt").to(model.device)
36
 
37
  with torch.no_grad():
38
  outputs = model.generate(
39
  **inputs,
40
+ max_new_tokens=max_new_tokens,
41
+ temperature=temperature,
42
+ top_p=top_p,
43
  do_sample=True,
 
 
44
  repetition_penalty=1.05,
45
+ eos_token_id=tokenizer.eos_token_id,
46
  )
47
 
48
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
+ answer = decoded[len(text):].strip() if decoded.startswith(text) else decoded
50
  return answer
51
 
52
+ demo = gr.Interface(
53
+ fn=generate_essay,
54
+ inputs=[
55
+ gr.Textbox(lines=8, label="Nhập đề văn / yêu cầu"),
56
+ gr.Slider(128, 1024, value=512, step=32, label="max_new_tokens"),
57
+ gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="temperature"),
58
+ gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
59
+ ],
60
+ outputs=gr.Textbox(lines=16, label="Bài làm"),
61
+ title="Qwen2.5-3B Văn AI (LoRA)",
62
+ description="Sinh bài văn tiếng Việt từ model LoRA đã fine-tune.",
63
  )
64
 
65
  if __name__ == "__main__":