# models/llm.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "microsoft/phi-2"
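
# Phi-2 has ~2.7B parameters, so its float32 weights occupy roughly 10 GB
# of RAM; caching the loaded model (see load_llm below) avoids paying that
# cost on every Streamlit rerun.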
@st.cache_resource(show_spinner="Loading Phi-2 model...")
def load_llm():
    # Load the tokenizer and model once; st.cache_resource shares the
    # pair across sessions and script reruns.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32  # CPU safe
    )
    model.eval()  # inference mode: disables dropout
    return tokenizer, model
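
# A possible GPU variant (an assumption, not part of this app): pass
# torch_dtype=torch.float16 and device_map="auto" (requires the
# `accelerate` package) to from_pretrained. float32 is kept here because
# the app targets CPU-only hosts, where half precision is slow or unsupported.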

def generate(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.2
) -> str:
    tokenizer, model = load_llm()
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,  # clip prompts longer than Phi-2's 2048-token context
        max_length=2048
    )
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id  # Phi-2 has no pad token; reuse EOS
        )
    # model.generate echoes the input ids, so slice them off and decode
    # only the newly generated tokens instead of returning the prompt too.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
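

# Illustrative usage (a sketch, not part of the app: the prompt below is
# hypothetical and assumes Phi-2's "Instruct:/Output:" QA prompt format):
if __name__ == "__main__":
    print(generate(
        "Instruct: Summarize what this module does.\nOutput:",
        max_new_tokens=64,
    ))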