"""Model serving utilities: load a causal LM (optionally 4-bit/8-bit quantized via bitsandbytes) and generate text."""
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Globals populated by load_model().
model = None
tokenizer = None

MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen3-30B-A3B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "false").lower() == "true"
LOAD_IN_8BIT = os.environ.get("LOAD_IN_8BIT", "false").lower() == "true"

if LOAD_IN_4BIT and LOAD_IN_8BIT:
    print("Warning: Both LOAD_IN_4BIT and LOAD_IN_8BIT are set to true. Prioritizing 4-bit.")
    LOAD_IN_8BIT = False
elif not LOAD_IN_4BIT and not LOAD_IN_8BIT:
    print(
        "Info: No explicit quantization (4-bit/8-bit) requested via environment variables. "
        "Loading in default precision (e.g., bfloat16 on GPU)."
    )


def load_model():
    """Load the model and tokenizer at application startup."""
    global model, tokenizer
    if model is None or tokenizer is None:
        quantization_info = "No Quantization"
        if LOAD_IN_4BIT:
            quantization_info = "4-bit Quantization"
        elif LOAD_IN_8BIT:
            quantization_info = "8-bit Quantization"

        print(f"Loading model: {MODEL_ID} on device: {DEVICE} with {quantization_info}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            model_kwargs = {"trust_remote_code": True}
            quantization_config = None
            if DEVICE == "cuda":
                # Let accelerate place the weights across available devices.
                model_kwargs["device_map"] = "auto"
                if LOAD_IN_4BIT:
                    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
                    model_kwargs["torch_dtype"] = "auto"
                elif LOAD_IN_8BIT:
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                else:
                    # No quantization requested: load full weights in bfloat16.
                    model_kwargs["torch_dtype"] = torch.bfloat16
            else:
                if LOAD_IN_4BIT or LOAD_IN_8BIT:
                    # quantization_config stays None in this branch, so
                    # bitsandbytes quantization is not actually applied on CPU.
                    print(
                        "Warning: bitsandbytes quantization (4-bit/8-bit) is primarily for GPU. "
                        "Attempting on CPU may be slow or unstable."
                    )

            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID, **model_kwargs, quantization_config=quantization_config
            )

            if DEVICE == "cpu" and not (LOAD_IN_4BIT or LOAD_IN_8BIT):
                model = model.to(DEVICE)

            model.eval()
            print(f"Model {MODEL_ID} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {MODEL_ID}: {e}")
            raise RuntimeError(f"Failed to load model: {e}")


def generate_text(
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.3,
    top_p: float = 0.9,
    repetition_penalty: float = 1.0,
) -> str:
    """Generate text from a prompt using the loaded model."""
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded. Cannot generate text.")

    try:
        messages = [{"role": "user", "content": prompt}]

        # Prefer the tokenizer's built-in chat template when available.
        try:
            prompt_formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            print(
                f"Warning: tokenizer.apply_chat_template failed for {MODEL_ID}. "
                "Using raw prompt or basic formatting."
            )
            if "stablelm-instruct" in MODEL_ID.lower() or "elyza" in MODEL_ID.lower():
                prompt_formatted = f"ユーザー: {prompt}\nシステム: "
            elif "qwen" in MODEL_ID.lower() and "chat" in MODEL_ID.lower():
                prompt_formatted = (
                    f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
                )
            else:
                prompt_formatted = prompt

        inputs = tokenizer(
            prompt_formatted, return_tensors="pt", add_special_tokens=False
        ).to(DEVICE)

        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": temperature > 0,
            "pad_token_id": tokenizer.pad_token_id,
        }

        outputs = model.generate(**inputs, **generation_kwargs)

        # Decode only the newly generated tokens, skipping the prompt portion.
        output_text = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        return output_text.strip()

    except Exception as e:
        print(f"Error during text generation: {e}")
        raise RuntimeError(f"Text generation failed: {e}")