# File size: 2,226 Bytes
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Hub identifiers: the base checkpoint and the LoRA adapter applied on top of it.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
LORA_REPO = "Doanlol/qwen25-3b-van-lora"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Fall back to the EOS token when the tokenizer ships without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Half precision and device_map="auto" are requested only when CUDA is
# available; otherwise the model is loaded in float32 with no device map.
if torch.cuda.is_available():
    _dtype, _device_map = torch.float16, "auto"
else:
    _dtype, _device_map = torch.float32, None

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=_dtype,
    device_map=_device_map,
    trust_remote_code=True,
)

# Wrap the base model with the LoRA adapter and switch to evaluation mode.
model = PeftModel.from_pretrained(base_model, LORA_REPO)
model.eval()

# System message prepended to every conversation sent to the model.
SYSTEM_PROMPT = "Bạn là trợ lý viết văn tiếng Việt, lập luận rõ ràng, cảm xúc, đúng trọng tâm đề."
def generate_essay(prompt, max_new_tokens, temperature, top_p):
    """Generate an essay for *prompt* with the module-level LoRA model.

    The prompt is wrapped in a two-message chat (system + user), rendered
    with the tokenizer's chat template, and sampled from ``model``.

    Args:
        prompt: The essay topic or request entered by the user.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        Only the newly generated text, with special tokens removed and
        surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # Remember how many tokens the prompt occupies so they can be cut off
    # the front of the generated sequence afterwards.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # UI sliders may hand over a float; the token budget must be an int.
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Slice at the token level instead of comparing decoded strings: the
    # rendered template contains special tokens that skip_special_tokens
    # removes, so a string-prefix match against it does not succeed and the
    # system/user text would leak into the answer.
    completion_ids = outputs[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
# Input widgets, in the positional order expected by generate_essay:
# prompt, max_new_tokens, temperature, top_p.
_prompt_box = gr.Textbox(lines=8, label="Nhập đề văn / yêu cầu")
_max_tokens_slider = gr.Slider(
    minimum=128, maximum=1024, value=512, step=32, label="max_new_tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="top_p"
)

# Single text area that displays the generated essay.
_essay_box = gr.Textbox(lines=16, label="Bài làm")

demo = gr.Interface(
    fn=generate_essay,
    inputs=[_prompt_box, _max_tokens_slider, _temperature_slider, _top_p_slider],
    outputs=_essay_box,
    title="Qwen2.5-3B Văn AI (LoRA)",
    description="Sinh bài văn tiếng Việt từ model LoRA đã fine-tune.",
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()