File size: 2,226 Bytes
d40385d
b4b9b88
8c02943
 
d40385d
b4b9b88
 
d40385d
8c02943
b4b9b88
 
d40385d
8c02943
 
 
b4b9b88
8c02943
 
 
b4b9b88
 
8c02943
b4b9b88
d40385d
b4b9b88
 
 
 
 
8c02943
b4b9b88
 
 
8c02943
 
 
 
 
 
b4b9b88
 
 
8c02943
 
b4b9b88
8c02943
 
b4b9b88
 
8c02943
 
b4b9b88
 
 
 
 
 
 
 
 
 
 
d40385d
 
 
8c02943
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Identifiers handed to the `from_pretrained` loaders below.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
LORA_REPO = "Doanlol/qwen25-3b-van-lora"

# Instruction sent as the system turn of every chat request.
SYSTEM_PROMPT = "Bạn là trợ lý viết văn tiếng Việt, lập luận rõ ràng, cảm xúc, đúng trọng tâm đề."

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# With CUDA: half precision and device_map="auto".
# Without CUDA: float32 and no device map.
if torch.cuda.is_available():
    _dtype, _device_map = torch.float16, "auto"
else:
    _dtype, _device_map = torch.float32, None

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=_dtype,
    device_map=_device_map,
    trust_remote_code=True,
)

# Attach the LoRA adapter to the base model, then switch to eval mode.
model = PeftModel.from_pretrained(base_model, LORA_REPO)
model.eval()

def generate_essay(prompt, max_new_tokens, temperature, top_p):
    """Generate an essay for ``prompt`` with the LoRA-adapted chat model.

    The request is wrapped in a two-turn chat (the module-level
    ``SYSTEM_PROMPT`` plus the user's text), rendered with the tokenizer's
    chat template, and sampled with ``model.generate``.

    Args:
        prompt: Essay topic or instruction entered by the user.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Cast to ``int`` in case the UI slider delivers a float.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        Only the newly generated text, stripped of surrounding whitespace.
        An empty string is returned when ``prompt`` is empty or blank.
    """
    # Nothing to answer: skip the expensive generation call entirely.
    if not prompt or not prompt.strip():
        return ""

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # The returned sequence is the prompt tokens followed by the continuation.
    # Cut by token count rather than comparing decoded strings: decoding with
    # skip_special_tokens=True drops the chat-template markers, so the decoded
    # text does not start with `chat_text` and a string-prefix check would
    # leak the system and user turns into the answer.
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# Input widgets, listed in the positional order of generate_essay's parameters.
_prompt_box = gr.Textbox(lines=8, label="Nhập đề văn / yêu cầu")
_max_tokens_slider = gr.Slider(
    minimum=128, maximum=1024, value=512, step=32, label="max_new_tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="top_p"
)

# Output widget that shows the generated essay.
_essay_box = gr.Textbox(lines=16, label="Bài làm")

# Web UI wiring the widgets above to generate_essay.
demo = gr.Interface(
    fn=generate_essay,
    inputs=[_prompt_box, _max_tokens_slider, _temperature_slider, _top_p_slider],
    outputs=_essay_box,
    title="Qwen2.5-3B Văn AI (LoRA)",
    description="Sinh bài văn tiếng Việt từ model LoRA đã fine-tune.",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()