Mistral_Test / optimization.py
eesfeg's picture
Add application file
1e639fb
raw
history blame contribute delete
907 Bytes
# optimization.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# 1. Use pipeline for simplicity.
# FIX(review): passing "load_in_4bit": True directly inside model_kwargs is the
# deprecated quantization path in recent transformers releases; the supported way
# is a BitsAndBytesConfig passed as quantization_config (still requires the
# bitsandbytes package and a CUDA-capable device at runtime).
pipe = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    model_kwargs={
        # compute dtype for non-quantized modules
        "torch_dtype": torch.float16,
        # let accelerate place layers on available devices
        "device_map": "auto",
        # 4-bit weight quantization (replaces deprecated load_in_4bit=True)
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    },
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
# 2. Use vLLM for high-throughput inference (install: pip install vllm)
from vllm import LLM, SamplingParams

# BUG FIX: the model id was "mTinyLlama/TinyLlama-1.1B-Chat-v1.0" (stray leading
# "m"); the correct Hugging Face repo id is "TinyLlama/TinyLlama-1.1B-Chat-v1.0".
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# Sampling config: temperature 0.7, up to 500 generated tokens per prompt.
sampling_params = SamplingParams(temperature=0.7, max_tokens=500)
# generate() takes a list of prompts and returns one RequestOutput per prompt.
outputs = llm.generate(["Hello, how are you?"], sampling_params)
# 3. Cache model responses
import hashlib
from functools import lru_cache


@lru_cache(maxsize=1000)
def cached_generation(prompt, max_tokens=500):
    """Run the text-generation pipeline for *prompt* and return the generated text.

    Results are memoized by (prompt, max_tokens) via lru_cache, bounded to the
    1000 most recently used entries, so repeated identical requests skip the
    expensive model forward pass.
    """
    results = pipe(prompt, max_new_tokens=max_tokens)
    return results[0]['generated_text']