# main.py — demo script: ask Qwen/Qwen2.5-7B-Instruct for a simple
# explanation of the Past Simple tense and print the response.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
def main():
    """Load Qwen2.5-7B-Instruct, ask it to explain the Past Simple tense, and print the reply."""
    # Tokenizer for the configured checkpoint.
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Pick a suitable dtype: bfloat16 when a GPU is available, float32 on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Load the model; device_map="auto" distributes weights automatically.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto",
    )

    # Chat prompt: explain Past Simple in simple English for A2 learners.
    messages = [
        {"role": "system", "content": "You are a friendly English teacher. Explain clearly and simply."},
        {"role": "user", "content": "Explain the Past Simple tense in very simple English. Give rules and 8 short examples. Keep it clear for A2 learners."},
    ]

    # Render the chat messages into the model's prompt format.
    prompt = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tok([prompt], return_tensors="pt").to(model.device)

    # Sample a completion without tracking gradients.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=400,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Drop the prompt tokens so only the newly generated text is decoded.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tok.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    print("\n=== Model Response ===\n")
    print(reply.strip())
    print("\n======================\n")


if __name__ == "__main__":
    main()