"""
Language model for text generation
"""
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from config import LLM_MODEL, DEVICE, MODELS_DIR, MAX_NEW_TOKENS, TEMPERATURE, TOP_P
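
# NOTE: the `config` module is not included in this file. A minimal sketch of the
# values it is assumed to provide (names taken from the import above; the concrete
# values below are illustrative guesses, not the project's actual settings):
#
#     LLM_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # any causal LM repo id
#     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#     MODELS_DIR = Path("./models_cache")                # Hugging Face cache dir
#     MAX_NEW_TOKENS = 256
#     TEMPERATURE = 0.7
#     TOP_P = 0.9
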
class LanguageModel:
    def __init__(self):
        print(f"Loading language model: {LLM_MODEL} on {DEVICE}...")

        # Optional 8-bit quantization on GPU to reduce memory use
        quantization_config = None
        if DEVICE == "cuda":
            try:
                # 8-bit quantization via bitsandbytes; if the package is missing,
                # the failure typically surfaces when the model is loaded below
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0,
                )
                print("Using 8-bit quantization")
            except Exception:
                print("8-bit quantization not available, using full precision")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            LLM_MODEL,
            cache_dir=str(MODELS_DIR),
            trust_remote_code=True,
        )
        # Fall back to the EOS token when the tokenizer defines no pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL,
            cache_dir=str(MODELS_DIR),
            quantization_config=quantization_config,
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True,
        )
        if DEVICE == "cpu":
            self.model = self.model.to(DEVICE)
        self.model.eval()
        print(f"✓ Language model loaded on {DEVICE}")
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
        top_p: float = TOP_P,
    ) -> str:
        """
        Generate text from a prompt.

        Chat-tuned models (e.g. TinyLlama-Chat) require their chat template;
        without it the model tends to loop and cause timeouts. Greedy decoding
        is used for speed, so `temperature` and `top_p` are kept only for API
        compatibility and are ignored.
        """
        # Wrap the prompt in the TinyLlama-style chat template when the model
        # name suggests a chat/instruct variant
        is_chat = "chat" in LLM_MODEL.lower() or "instruct" in LLM_MODEL.lower()
        if is_chat:
            formatted_prompt = (
                "<|system|>\nYou are an expert educational AI that strictly "
                f"follows instructions.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
            )
        else:
            formatted_prompt = prompt

        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=768,  # cap input; single-MCQ prompts are < 512 tokens
        ).to(self.model.device)

        print(f"🤖 Generating locally (up to {max_new_tokens} tokens). "
              "This may take 2-4 minutes on a free CPU Space...")

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,         # greedy decoding -- roughly 3x faster than sampling
                repetition_penalty=1.1,  # discourage repetition loops
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt, so the
        # caller receives just the generated MCQs
        input_len = inputs["input_ids"].shape[1]
        generated_text = self.tokenizer.decode(
            outputs[0][input_len:], skip_special_tokens=True
        )
        return generated_text.strip()

# Singleton instance
_llm_model = None


def get_llm_model() -> LanguageModel:
    """Get or create the shared LanguageModel instance."""
    global _llm_model
    if _llm_model is None:
        _llm_model = LanguageModel()
    return _llm_model
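

# Minimal usage sketch (illustrative, not the app's actual entry point): it assumes
# `config` is importable and that the model weights can be downloaded into MODELS_DIR.
# Generation will be slow on CPU for anything beyond a short prompt.
if __name__ == "__main__":
    llm = get_llm_model()
    answer = llm.generate(
        "Write one multiple-choice question about photosynthesis with 4 options.",
        max_new_tokens=128,
    )
    print(answer)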