# KIMI / app.py — Hugging Face Space
# Author: Sachin23991 — "Update app.py" (commit 899b2f7, verified)
import gradio as gr
from huggingface_hub import InferenceClient
import os
# Hugging Face API token read from the environment (set via Space Secrets).
# May be None when unset; chat_fn checks for this before calling the API.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Free-ish models that usually work on HF Inference API
# Maps UI display name -> fully-qualified model id on the Hugging Face Hub.
MODELS = {
"Llama 3 8B": "meta-llama/Meta-Llama-3-8B-Instruct",
"Mistral 7B": "mistralai/Mistral-7B-Instruct-v0.2",
}
# Single shared Inference API client; token may be None (unauthenticated).
client = InferenceClient(token=HF_TOKEN)
def approximate_token_count(text):
    """Crudely estimate the token count of *text* (assume ~4 chars per token)."""
    chars_per_token = 4
    return len(text) // chars_per_token
def chat_fn(message, history, model_name, max_tokens, temperature):
    """Stream a chat completion for *message*, yielding growing partial replies.

    Args:
        message: The new user message (str).
        history: Prior conversation. Accepts both Gradio history formats:
            a list of (user, assistant) pairs, or a list of
            {"role": ..., "content": ...} dicts (messages format).
        model_name: Key into MODELS selecting the backend model.
        max_tokens: Maximum tokens to generate (coerced to int).
        temperature: Sampling temperature (coerced to float).

    Yields:
        str: The accumulated assistant response so far (Gradio streaming),
        or a human-readable error message on failure.
    """
    if not HF_TOKEN:
        yield "HF_TOKEN not set in Space Secrets."
        return
    model_id = MODELS[model_name]
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            # Tuple-format history: one (user, assistant) pair per turn.
            user_msg, bot_msg = item
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        elif isinstance(item, dict) and item.get("role") and item.get("content"):
            # Messages-format history (newer Gradio ChatInterface default);
            # previously these entries were silently dropped, losing context.
            messages.append({"role": item["role"], "content": item["content"]})
    messages.append({"role": "user", "content": message})
    # 🔒 Basic token safety: refuse over-long conversations (~3500 est. tokens).
    total_text = " ".join([m["content"] for m in messages])
    if approximate_token_count(total_text) > 3500:
        yield "Conversation too long. Please refresh."
        return
    try:
        stream = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            stream=True,
        )
        partial = ""
        for chunk in stream:
            # Some stream chunks carry no delta content (e.g. role-only frames).
            if chunk.choices and chunk.choices[0].delta.content:
                partial += chunk.choices[0].delta.content
                yield partial
    except Exception as e:
        # Surface backend errors in the chat UI instead of crashing the stream.
        yield f"Inference Error:\n{str(e)}"
with gr.Blocks() as demo:
    gr.Markdown("## Free Multi-Model Chat (HF Inference API)")

    # Controls wired into chat_fn as additional inputs.
    model_dropdown = gr.Dropdown(
        choices=list(MODELS.keys()),
        value="Llama 3 8B",
        label="Select Model",
    )
    max_tokens_slider = gr.Slider(
        64, 2048, value=512, step=64, label="Max Tokens"
    )
    temperature_slider = gr.Slider(
        0.1, 1.5, value=0.7, step=0.1, label="Temperature"
    )

    chat = gr.ChatInterface(
        fn=chat_fn,
        additional_inputs=[model_dropdown, max_tokens_slider, temperature_slider],
        chatbot=gr.Chatbot(height=400),
    )

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port; SSR disabled.
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)