| import warnings |
| warnings.filterwarnings("ignore", category=FutureWarning) |
| warnings.filterwarnings("ignore", category=UserWarning) |
|
|
| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
| from threading import Thread |
|
|
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = "ecomindia/ecom-test"

# System-role message placed first in every prompt built by respond().
SYSTEM = (
    "You are ECOM bot, an expert assistant "
    "for ECOM's Android devices."
)

# Report the installed Gradio version at startup.
print(f"Gradio version: {gr.__version__}")
|
|
| |
# Pick precision and placement: float32 with no device map when CUDA is
# unavailable, float16 with automatic device mapping when it is.
if not torch.cuda.is_available():
    DTYPE, DEVICE_MAP = torch.float32, None
    print("CPU mode")
else:
    DTYPE, DEVICE_MAP = torch.float16, "auto"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
|
|
| |
# Load the tokenizer that belongs to MODEL_ID.
print("Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)

# When the tokenizer defines no padding token, reuse the end-of-sequence
# token (and its id) for padding.
if tok.pad_token is None:
    tok.pad_token, tok.pad_token_id = tok.eos_token, tok.eos_token_id
|
|
| |
# Load the causal language model with the dtype / device map chosen above.
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map=DEVICE_MAP,
    torch_dtype=DTYPE,
)

# Without a device map the model is placed on the CPU explicitly.
if DEVICE_MAP is None:
    model = model.to("cpu")

# Switch to evaluation mode before serving requests.
model.eval()
print("Ready.")
|
|
| |
def to_str(content):
    """Flatten a chat message ``content`` value into plain text.

    Args:
        content: A string, a list of content parts, ``None``, or any other
            object. In a list, dict parts contribute their ``"text"`` entry
            (missing entries contribute nothing) and every other part is
            stringified.

    Returns:
        The string itself; the concatenated text of the list parts; ``""``
        for ``None``; otherwise ``str(content)``.
    """
    # A missing content must not leak the literal text "None" into a prompt.
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for item in content:
            text = item.get("text", "") if isinstance(item, dict) else item
            if text is None:
                # Parts without usable text are skipped instead of breaking
                # the join below.
                continue
            pieces.append(text if isinstance(text, str) else str(text))
        return "".join(pieces)
    return str(content)
|
|
| |
def respond(message, history, max_tokens, temperature):
    """Stream the assistant's reply to ``message`` given the prior ``history``.

    Builds a chat prompt from the system prompt, the non-empty turns of
    ``history`` and the new user ``message``, runs ``model.generate`` on a
    background thread, and yields the reply accumulated so far every time the
    streamer delivers more text.

    Args:
        message: New user message (anything ``to_str`` accepts).
        history: Earlier turns as dicts with ``role`` and ``content`` keys;
            ``None`` is treated as an empty history.
        max_tokens: Upper bound on newly generated tokens (coerced to int).
        temperature: Sampling temperature (coerced to float). Values <= 0
            disable sampling.

    Yields:
        The partial reply text, growing with each streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here after the stream has been closed.
    """
    messages = [{"role": "system", "content": SYSTEM}]
    for entry in history or []:
        content = to_str(entry.get("content", ""))
        if content:
            messages.append(
                {"role": entry.get("role", "user"), "content": content}
            )
    messages.append({"role": "user", "content": to_str(message)})

    prompt = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template normally already contains the special tokens it
    # needs, so the tokenizer must not add another set on top of them.
    encoded = tok(prompt, return_tensors="pt", add_special_tokens=False)

    streamer = TextIteratorStreamer(
        tok, skip_prompt=True, skip_special_tokens=True,
    )

    temperature = float(temperature)
    sampling = temperature > 0
    gen_kwargs = dict(
        input_ids=encoded["input_ids"].to(model.device),
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=sampling,
        repetition_penalty=1.1,
    )
    # Pass the mask explicitly: the pad token may be aliased to the EOS token,
    # in which case it cannot be inferred reliably from the ids.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(model.device)
    # A temperature is only meaningful (and only valid) when sampling.
    if sampling:
        gen_kwargs["temperature"] = temperature

    failures = []

    def _generate():
        """Run generation; on failure record the error and close the stream."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised on the consumer side below
            failures.append(exc)
            # Without the end-of-stream signal the loop below would wait
            # forever for text that will never arrive.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial

    worker.join()
    if failures:
        raise failures[0]
|
|
| |
# Chat UI: a chatbot pane, a message box, generation sliders, Send/Clear
# buttons and a few example prompts.
with gr.Blocks(title="ECOM AI Agent") as demo:
    # NOTE(review): the "β" in this heading may be a mis-encoded dash —
    # confirm the intended text before changing it.
    gr.Markdown("## ECOM AI agent β Your Product Assistant")

    chatbot = gr.Chatbot(height=450)
    msg = gr.Textbox(placeholder="Ask a question...", label="Your message")

    with gr.Row():
        max_tokens = gr.Slider(64, 1024, value=512, step=64, label="Max tokens")
        temperature = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Temperature")

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    gr.Examples(
        examples=[
            "How do I upgrade my plan?",
            "What happens if my payment fails?",
            "How do I reset my API key?",
        ],
        inputs=msg,
    )

    def user_turn(user_message, history):
        """Append the user's message plus an empty assistant placeholder.

        Returns:
            A pair ``("", history)``: the empty string clears the textbox and
            ``history`` is the updated list of role/content dicts.
        """
        history = list(history or [])
        history.append({"role": "user", "content": user_message})
        history.append({"role": "assistant", "content": ""})
        return "", history

    def bot_turn(history, max_tok, temp):
        """Fill the trailing assistant placeholder with the streamed reply.

        Yields:
            ``history`` after each streamed chunk, with the last entry's
            content set to the reply produced so far.
        """
        # Locate the most recent user turn; that is the message to answer.
        user_idx = next(
            (
                i
                for i in range(len(history) - 1, -1, -1)
                if history[i].get("role") == "user"
            ),
            None,
        )
        if user_idx is None:
            user_message = ""
            prior = history[:-1]
        else:
            user_message = to_str(history[user_idx].get("content", ""))
            # respond() appends the message itself, so the context must stop
            # before the pending user turn; otherwise the prompt would contain
            # the same user message twice in a row.
            prior = history[:user_idx]

        for chunk in respond(user_message, prior, max_tok, temp):
            history[-1]["content"] = chunk
            yield history

    # Pressing Enter in the textbox and clicking "Send" run the same chain:
    # record the user turn, then stream the reply into the chatbot.
    for trigger in (msg.submit, submit_btn.click):
        trigger(user_turn, [msg, chatbot], [msg, chatbot]).then(
            bot_turn, [chatbot, max_tokens, temperature], chatbot
        )

    # "Clear" empties both the conversation and the textbox.
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])


if __name__ == "__main__":
    demo.launch()