import os

import streamlit as st
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from safetensors.torch import load_file as safe_load

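# Streamlit chat front-end for microsoft/phi-2 with a local LoRA adapter
# ("DeepTalks"). The base model is loaded once per server process (4-bit NF4
# on GPU, full precision on CPU), the adapter weights are applied on top, and
# replies are generated greedily from the most recent conversation turns.
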
MODEL_REPO = "models/phi2-deeptalk-lora"  # local directory holding the LoRA adapter
BASE_MODEL = "microsoft/phi-2"
CONTEXT_TURNS = 7        # past user/assistant exchanges kept in the prompt
MAX_NEW_TOKENS = 32
TEMPERATURE = 0.0        # ignored while do_sample=False (greedy decoding)
TOP_P = 1.0              # ignored while do_sample=False
DEVICE_MAP = "auto"

SYSTEM = (
    "You are a helpful assistant for DeepTalks, built on a Phi-2 base model "
    "fine-tuned by Sourish for domain-specific support.\n"
    "Base your replies **only** on the context below. "
    "If you don't know, say \"I don't know.\"\n"
)

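# For reference, the prompt assembled further below looks roughly like:
#
#   <SYSTEM>
#   Context:
#   User: ...
#   Assistant: ...
#   User: <latest message>
#   Assistant:
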
@st.cache_resource(show_spinner=False)
def load_generator():
    """Load the tokenizer, base model, and LoRA adapter once per server process."""
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        padding_side="left",
    )
    # Phi-2's tokenizer ships without a pad token; add one so padding works.
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    if torch.cuda.is_available():
        # 4-bit NF4 quantization keeps Phi-2 within a small GPU memory budget.
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            quantization_config=bnb,
            device_map=DEVICE_MAP,
            low_cpu_mem_usage=True,  # reduces peak RAM while loading the checkpoint
        )
    else:
        # CPU fallback: no CUDA here, so load the weights in full precision.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map=DEVICE_MAP,
        )

    # The tokenizer gained a [PAD] token above, so grow the embedding matrix
    # to match before attaching the adapter.
    base.resize_token_embeddings(len(tokenizer))

    # Recreate the LoRA wrapper from the saved config, then load the trained
    # adapter weights. set_peft_model_state_dict maps the saved adapter keys
    # onto the active adapter (named "default").
    peft_config = LoraConfig.from_pretrained(MODEL_REPO, local_files_only=True)
    model = get_peft_model(base, peft_config)

    adapter_file = os.path.join(MODEL_REPO, "adapter_model.safetensors")
    state_dict = safe_load(adapter_file)
    set_peft_model_state_dict(model, state_dict)
    model.eval()

    # Greedy text-generation pipeline; return_full_text=False strips the prompt
    # so only the newly generated completion is returned.
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        use_cache=True,
        return_full_text=False,
    )
    return tokenizer, gen


tokenizer, generator = load_generator()

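# Launch locally with (script name assumed):
#   streamlit run app.py
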
st.set_page_config(layout="centered")
st.title("🧠 Memory-Aware Phi-2 Chat")

| if "history" not in st.session_state: |
| st.session_state.history = [] |
|
|
| |
# Re-render the conversation so far on every Streamlit rerun.
for role, text in st.session_state.history:
    st.chat_message("user" if role == "You" else "assistant").write(text)

user_input = st.chat_input("Type your message...")

if user_input:
    # Show and record the new user message.
    st.chat_message("user").write(user_input)
    st.session_state.history.append(("You", user_input))

    # Keep only the last CONTEXT_TURNS exchanges (two entries per exchange).
    recent = st.session_state.history[-CONTEXT_TURNS * 2:]
    ctx = "\n".join(
        f"{'User' if r == 'You' else 'Assistant'}: {t}" for r, t in recent
    )

    prompt = f"{SYSTEM}\nContext:\n{ctx}\nUser: {user_input}\nAssistant:"

    # Generate the reply; fall back to an apology if generation fails.
    with st.spinner("Thinking..."):
        try:
            out = generator(prompt)[0]["generated_text"].strip()
        except Exception as e:
            out = "Sorry, I encountered an error."
            st.error(f"Generation error: {e}")

    st.chat_message("assistant").write(out)
    st.session_state.history.append(("Bot", out))