"""Gradio chat UI that streams replies from the ECOM causal language model."""

import warnings
from threading import Thread

# The filters are installed before the library imports below so that they are
# already active while those imports run.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_ID = "ecomindia/ecom-test"
SYSTEM = "You are ECOM bot, an expert assistant for ECOM's Android devices."

print(f"Gradio version: {gr.__version__}")

# ── Device ────────────────────────────────────────────────────
if torch.cuda.is_available():
    DTYPE = torch.float16
    DEVICE_MAP = "auto"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    DTYPE = torch.float32
    DEVICE_MAP = None
    print("CPU mode")

# ── Tokenizer ─────────────────────────────────────────────────
print("Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    trust_remote_code=True,
)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id

# ── Model ─────────────────────────────────────────────────────
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map=DEVICE_MAP,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
if DEVICE_MAP is None:
    model = model.to("cpu")
model.eval()
print("Ready.")


# ── Helper: normalize content to plain string ─────────────────
def to_str(content) -> str:
    """Flatten a chat message's content into a plain string.

    Args:
        content: A string, a list whose items are dicts carrying a ``"text"``
            key (or arbitrary objects), ``None``, or any other object.

    Returns:
        The string itself; the concatenated item texts for a list; ``""`` for
        ``None``; ``str(content)`` for anything else.
    """
    if content is None:
        # Avoid leaking the literal text "None" into the prompt.
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            # A missing or None "text" field contributes nothing.
            str(item.get("text") or "") if isinstance(item, dict) else str(item)
            for item in content
        )
    return str(content)


# ── Generation ────────────────────────────────────────────────
def _generate_in_background(streamer, generate_kwargs, errors):
    """Run ``model.generate`` and guarantee the streamer is terminated.

    If generation raises, the exception is recorded in ``errors`` and the
    streamer is ended explicitly so the consumer loop does not block forever.
    """
    try:
        model.generate(streamer=streamer, **generate_kwargs)
    except Exception as exc:  # worker-thread boundary: hand the error back
        errors.append(exc)
        streamer.end()


def respond(message, history, max_tokens, temperature):
    """Stream the assistant's reply to ``message``.

    Args:
        message: The new user message (string or structured content).
        history: Earlier turns as dicts with ``"role"`` and ``"content"``
            keys. It must NOT contain ``message`` itself, which is appended
            here as the final user turn.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; ``0`` selects non-sampling decoding.

    Yields:
        The reply text accumulated so far, once per streamed chunk.

    Raises:
        RuntimeError: If ``model.generate`` fails in the worker thread.
    """
    messages = [{"role": "system", "content": SYSTEM}]
    for entry in history:
        text = to_str(entry.get("content", ""))
        if text:  # skip empty turns such as a pending assistant placeholder
            messages.append({"role": entry.get("role", "user"), "content": text})
    messages.append({"role": "user", "content": to_str(message)})

    # apply_chat_template can return either a plain tensor or a
    # BatchEncoding dict depending on the transformers version.
    # Render to a string first, then call tok() on it to get plain tensors.
    prompt = tok.apply_chat_template(
        messages,
        tokenize=False,  # render to string first
        add_generation_prompt=True,
    )
    # The rendered template already carries its special tokens, so do not let
    # the tokenizer add them a second time.
    encoded = tok(prompt, return_tensors="pt", add_special_tokens=False)

    streamer = TextIteratorStreamer(
        tok,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    temp = float(temperature)
    sampling = temp > 0
    generate_kwargs = dict(
        input_ids=encoded["input_ids"].to(model.device),
        # Passed explicitly because the pad token may alias EOS, in which case
        # the mask cannot be inferred from the input ids.
        attention_mask=encoded["attention_mask"].to(model.device),
        max_new_tokens=int(max_tokens),
        do_sample=sampling,
        repetition_penalty=1.1,
    )
    if sampling:
        # Temperature is only meaningful (and only valid) when sampling.
        generate_kwargs["temperature"] = temp

    errors = []
    worker = Thread(
        target=_generate_in_background,
        args=(streamer, generate_kwargs, errors),
    )
    worker.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial

    worker.join()
    if errors:
        raise RuntimeError("Text generation failed") from errors[0]


# ── UI ────────────────────────────────────────────────────────
with gr.Blocks(title="ECOM AI Agent") as demo:
    gr.Markdown("## ECOM AI agent — Your Product Assistant")

    # NOTE(review): the callbacks below store history as role/content dicts;
    # some Gradio releases may need type="messages" here — confirm against the
    # installed version.
    chatbot = gr.Chatbot(height=450)
    msg = gr.Textbox(placeholder="Ask a question...", label="Your message")

    with gr.Row():
        max_tokens = gr.Slider(64, 1024, value=512, step=64, label="Max tokens")
        temperature = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Temperature")

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    gr.Examples(
        examples=[
            "How do I upgrade my plan?",
            "What happens if my payment fails?",
            "How do I reset my API key?",
        ],
        inputs=msg,
    )

    def user_turn(user_message, history):
        """Append the user's message plus an empty assistant placeholder.

        Blank submissions leave the history untouched. Returns the cleared
        textbox value and the (possibly updated) history.
        """
        history = history or []
        if not to_str(user_message).strip():
            return "", history
        history.append({"role": "user", "content": user_message})
        history.append({"role": "assistant", "content": ""})
        return "", history

    def bot_turn(history, max_tok, temp):
        """Fill the trailing assistant placeholder with the streamed reply.

        Yields the history after every streamed chunk. If there is no pending
        placeholder (e.g. after a blank submission), the history is yielded
        once, unchanged.
        """
        history = history or []

        last_user = None
        for idx in range(len(history) - 1, -1, -1):
            if history[idx].get("role") == "user":
                last_user = idx
                break

        pending = (
            bool(history)
            and history[-1].get("role") == "assistant"
            and not to_str(history[-1].get("content", ""))
        )
        if last_user is None or not pending:
            yield history
            return

        user_message = to_str(history[last_user].get("content", ""))
        # Only the turns BEFORE the current user message are prior context;
        # respond() appends the message itself, so including it here would
        # duplicate the user turn in the prompt.
        for chunk in respond(user_message, history[:last_user], max_tok, temp):
            history[-1]["content"] = chunk
            yield history

    # Enter in the textbox and the Send button share the same two-step flow.
    for trigger in (msg.submit, submit_btn.click):
        trigger(user_turn, [msg, chatbot], [msg, chatbot]).then(
            bot_turn, [chatbot, max_tokens, temperature], chatbot
        )
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])


if __name__ == "__main__":
    demo.launch()