File size: 1,753 Bytes
441e4e3
 
 
 
 
 
 
 
154d3ef
441e4e3
 
 
 
 
 
 
154d3ef
441e4e3
 
 
154d3ef
 
441e4e3
154d3ef
 
 
 
 
 
 
 
441e4e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import torch
import transformers
from transformers import pipeline

model_path = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"

# Use HF_TOKEN from the environment if present (needed for gated/private models).
hf_token = os.getenv("HF_TOKEN")

print("Loading model...")
try:
    # Build the 4-bit quantization config explicitly. Passing load_in_4bit /
    # bnb_4bit_* as loose entries in model_kwargs is deprecated in recent
    # transformers releases; BitsAndBytesConfig via quantization_config is
    # the supported path. NOTE(review): this checkpoint is pre-quantized
    # ("-bnb-4bit"), so this config should match what is baked in — confirm
    # against the model card if loading warns about overriding it.
    quant_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
    )

    # Initialize pipeline for chat.
    # For quantized models, use device=0 instead of device_map="auto" to
    # avoid meta-tensor placement issues.
    # BUGFIX: torch_dtype was previously passed both as a top-level pipeline
    # kwarg AND inside model_kwargs, which recent transformers rejects with a
    # ValueError ("received torch_dtype twice"). Keep only the top-level one.
    pipeline_model = pipeline(
        "text-generation",
        model=model_path,
        device=0,  # load directly onto the first GPU
        torch_dtype=torch.bfloat16,
        token=hf_token,
        trust_remote_code=True,
        model_kwargs={
            "quantization_config": quant_config,
        },
    )

    print("Model loaded successfully!")

    # Smoke-test generation with a minimal two-turn chat.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]

    print("Testing generation...")
    # Render the chat into a single prompt string using the model's own
    # chat template; add_generation_prompt=True appends the assistant cue
    # so the model continues as the assistant.
    prompt = pipeline_model.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    outputs = pipeline_model(
        prompt,
        max_new_tokens=50,      # keep the smoke test short
        temperature=0.7,
        top_p=0.9,
        do_sample=True,         # sampling so temperature/top_p take effect
        return_full_text=False  # return only the newly generated tokens
    )

    response = outputs[0]["generated_text"]
    print(f"Test response: {response}")
    print("✅ Model test successful!")

except Exception as e:
    # Broad catch is intentional here: this is a standalone diagnostic script
    # and any failure (download, CUDA, quantization) should print a traceback
    # rather than crash silently.
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()