Spaces:
Build error
Build error
import html
import os
import time
from threading import Thread

import requests
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Local directory holding a saved copy of the model and tokenizer. load_model()
# downloads from the Hub only when this directory does not exist yet, so the
# copy avoids re-downloading (not reloading) on later app starts.
# Override with the MODEL_CACHE_PATH environment variable, e.g. to point at a
# persistent-storage mount.
MODEL_PATH = os.environ.get("MODEL_CACHE_PATH", "/mnt/data/Phi-4-Hindi")
# Hugging Face access token; None when HF_TOKEN is not set.
TOKEN = os.environ.get("HF_TOKEN")
# Hub repository the model and tokenizer are downloaded from.
MODEL_NAME = "DrishtiSharma/Phi-4-Hindi-quantized"
# Load Model & Tokenizer Once
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the model and tokenizer, saving a local copy at MODEL_PATH on first use.

    Decorated with ``st.cache_resource`` so Streamlit reuses the same objects
    across script reruns instead of reloading the weights on every interaction.
    Exceptions propagate, and Streamlit does not cache a call that raised, so a
    failed download is retried on the next run.

    Returns:
        ``(model, tokenizer)``.
    """
    # Use identical settings for the Hub download and the local copy. Without
    # them, the reload would use from_pretrained's default dtype instead of
    # bfloat16, and could not load custom (remote) model code.
    load_kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
    if not os.path.exists(MODEL_PATH):
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=TOKEN, **load_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
        model.save_pretrained(MODEL_PATH)
        tokenizer.save_pretrained(MODEL_PATH)
    else:
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **load_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    return model, tokenizer


def load_model():
    """Return ``(model, tokenizer)``, or ``(None, None)`` if loading failed.

    Shows a spinner while loading and an ``st.error`` message describing any
    failure.
    """
    with st.spinner("Loading model... Please wait ⏳"):
        try:
            return _load_model_and_tokenizer()
        except requests.exceptions.ConnectionError:
            st.error("⚠️ Connection error! Unable to download the model. Please check your internet connection and try again.")
            return None, None
        except requests.exceptions.ReadTimeout:
            st.error("⚠️ Read Timeout! The request took too long. Please try again later.")
            return None, None
        except OSError as err:
            # transformers reports unreachable-Hub and missing/corrupt model
            # files as OSError; surface the reason instead of crashing the app.
            st.error(f"⚠️ Failed to load the model: {err}")
            return None, None
# Load the model/tokenizer (stop the app if that failed) and move the model to
# the GPU when one is available.
model, tok = load_model()
if model is None or tok is None:
    st.stop()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    model = model.to(device)
except torch.cuda.OutOfMemoryError:
    st.error("⚠️ CUDA Out of Memory! Running on CPU instead.")
    device = torch.device("cpu")
    model = model.to(device)
    # Return the GPU memory held by the failed transfer to the driver.
    torch.cuda.empty_cache()

# Token ids that end generation: the tokenizer's end-of-sequence token.
terminators = [tok.eos_token_id]

# Initialize session state if not set
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# Chat function
def chat(message, temperature, do_sample, max_tokens):
    """Record *message* in the chat history and stream the model's reply.

    Appends ``{"role": "user", "content": message}`` to
    ``st.session_state.chat_history``, renders the whole history with the
    tokenizer's chat template, and runs ``model.generate`` on a background
    thread that feeds a ``TextIteratorStreamer``.

    Args:
        message: The user's message text.
        temperature: Sampling temperature; ``0`` forces greedy decoding.
        do_sample: Whether to sample (overridden to False when temperature is 0).
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        The cumulative response text after each streamed chunk.

    After the stream is exhausted, the full reply is appended to the history
    as an ``"assistant"`` message.
    """
    st.session_state.chat_history.append({"role": "user", "content": message})

    prompt = tok.apply_chat_template(
        st.session_state.chat_history, tokenize=False, add_generation_prompt=True
    )
    # The chat template already contains the model's special tokens; don't let
    # the tokenizer add them a second time.
    model_inputs = tok([prompt], return_tensors="pt", add_special_tokens=False).to(device)

    # Streams decoded text chunks, excluding the prompt and special tokens.
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

    sampling = bool(do_sample) and temperature != 0
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "attention_mask": model_inputs.get("attention_mask"),
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": sampling,
        "eos_token_id": terminators,
    }
    if sampling:
        # temperature only affects sampling; passing it with greedy decoding
        # just triggers a generation-config warning.
        generate_kwargs["temperature"] = temperature

    # Generate on a worker thread so this generator can yield chunks as they arrive.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    response_text = ""
    for new_text in streamer:
        response_text += new_text
        yield response_text
    # The streamer only ends once generate() has finished, so this join is brief.
    worker.join()

    # Save the assistant's response to session history
    st.session_state.chat_history.append({"role": "assistant", "content": response_text})
# UI Setup
st.title("💬 Chat With Phi-4-Hindi")
st.success("✅ Model is READY to chat!")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")

# Sidebar Chat Settings
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
# Colour used for the user's messages in the history display.
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
# NOTE(review): dark_mode is read but not used anywhere in this file, so the
# checkbox currently has no effect.
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)
# Function to format chat messages
def get_html_text(text, color):
    """Wrap *text* in a 16px ``<p>`` element coloured with the lower-cased *color*.

    *text* is inserted as-is (it is not HTML-escaped).
    """
    template = '<p style="color: {}; font-size: 16px;">{}</p>'
    css_color = color.lower()
    return template.format(css_color, text)
# Display chat history
for msg in st.session_state.chat_history:
    # Choose label and colour from the message role itself, so user and
    # assistant turns are always distinguishable.
    is_user = msg["role"] == "user"
    label = "👤" if is_user else "🤖"
    color = text_color if is_user else "black"
    # Message text comes from the user or the model and is rendered with
    # unsafe_allow_html=True, so it is HTML-escaped. <strong> is used because
    # Markdown emphasis inside a raw HTML <p> block is not rendered.
    line = f"<strong>{label}:</strong> {html.escape(msg['content'])}"
    st.markdown(get_html_text(line, color), unsafe_allow_html=True)

# User Input Handling
user_input = st.text_input("Type your message:", "")
if st.button("Send"):
    if user_input.strip():
        # chat() itself appends the user message (and then the assistant reply)
        # to st.session_state.chat_history, so it is not appended here.
        live_reply = st.empty()
        with st.spinner("Generating response... 🤖💭"):
            for partial in chat(user_input, temperature, do_sample, max_tokens):
                # Show the reply as it streams in; after st.rerun() the full
                # history (including this reply) is rendered by the loop above.
                live_reply.markdown(
                    get_html_text(f"<strong>🤖:</strong> {html.escape(partial)}", "black"),
                    unsafe_allow_html=True,
                )
        st.success("✅ Response generated!")
        st.rerun()
if st.button("🧹 Clear Chat"):
    with st.spinner("Clearing chat history..."):
        st.session_state.chat_history = []
    st.success("✅ Chat history cleared!")
    st.rerun()