import os
import time
from threading import Thread

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Load model and tokenizer
token = os.environ.get("HF_TOKEN")
model_name = "large-traversaal/Phi-4-Hindi"


@st.cache_resource()
def load_model():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=token,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    tok = AutoTokenizer.from_pretrained(model_name, token=token)
    return model, tok


model, tok = load_model()
terminators = [tok.eos_token_id]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize session state if not set
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []


# Chat function: records the user turn, streams the assistant reply, and
# stores the finished reply back into the session history.
def chat(message, temperature, do_sample, max_tokens):
    # Append the user message once here (appending it again in the button
    # handler would duplicate it in the prompt).
    st.session_state.chat_history.append({"role": "user", "content": message})
    chat_log = st.session_state.chat_history.copy()

    messages = tok.apply_chat_template(chat_log, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # Run generation in a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    st.session_state.chat_history.append({"role": "assistant", "content": partial_text})


# Streamlit UI
st.title("💬 Chat With Phi-4-Hindi")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")

# Sidebar controls
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)


def get_html_text(text, color):
    # Wrap the message in a colored <div>; the sidebar color names
    # ("Red", "Black", ...) are valid CSS color keywords.
    return f'<div style="color: {color};">{text}</div>'


# Render chat history
for msg in st.session_state.chat_history:
    if msg["role"] == "user":
        st.markdown(get_html_text("👤 " + msg["content"], "black"), unsafe_allow_html=True)
    else:
        st.markdown(get_html_text("🤖 " + msg["content"], text_color), unsafe_allow_html=True)

user_input = st.text_input("Type your message:", "")

if st.button("Send"):
    if user_input.strip():
        with st.spinner("Generating response..."):
            # Drain the stream; the final reply is saved to the history inside chat().
            for _ in chat(user_input, temperature, do_sample, max_tokens):
                pass
        st.experimental_rerun()

if st.button("🧹 Clear Chat"):
    st.session_state.chat_history = []
    st.experimental_rerun()
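
# Example usage (assumptions: this script is saved as app.py and the
# dependencies streamlit, torch, and transformers are installed):
#
#   HF_TOKEN=<your_hf_token> streamlit run app.py
#
# HF_TOKEN is only needed if downloading the model requires authentication.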