Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| import json | |
| import re | |
| import torch | |
| # Load tokenizer và model trên cloud HF | |
| model_name = "meetkai/functionary-small-v3.1" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # Config quantization để load model nhẹ (4bit, giảm time/memory trên CPU) | |
| quant_config = BitsAndBytesConfig( | |
| load_in_4bit=True, # 4bit để nhanh nhất trên free tier | |
| bnb_4bit_quant_type="nf4", # Loại quant tốt cho accuracy | |
| bnb_4bit_compute_dtype=torch.float16, # Compute ở float16 | |
| bnb_4bit_use_double_quant=True # Quant lồng để tiết kiệm thêm | |
| ) | |
| # Load model với quant và low mem | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| quantization_config=quant_config, | |
| low_cpu_mem_usage=True, | |
| trust_remote_code=True | |
| ) | |
| def chat_with_tools(messages_str, tools_str): | |
| try: | |
| # Parse inputs | |
| messages = json.loads(messages_str) | |
| tools = json.loads(tools_str) | |
| # Build prompt với tools (trả về tensor) | |
| prompt = tokenizer.apply_chat_template( | |
| messages, | |
| tools=tools, | |
| add_generation_prompt=True, | |
| return_tensors="pt" | |
| ) | |
| # Move prompt lên device (default CPU) | |
| prompt = prompt.to(model.device) | |
| # Generate thủ công (nhanh hơn với quant) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| prompt, | |
| max_new_tokens=512, | |
| do_sample=False, | |
| pad_token_id=tokenizer.eos_token_id | |
| ) | |
| # Decode output (bỏ prompt input) | |
| generated = tokenizer.decode(outputs[0][prompt.shape[1]:], skip_special_tokens=True) | |
| # Parse tool call nếu có | |
| match = re.search(r'<function=(\w+)>(.*?)</function>', generated, re.DOTALL) | |
| if match: | |
| func_name = match.group(1) | |
| args = json.loads(match.group(2)) | |
| return {"tool_call": {"name": func_name, "arguments": args}, "content": generated} | |
| return {"content": generated} | |
| except Exception as e: | |
| return {"error": str(e)} | |
| # Gradio interface | |
| iface = gr.Interface( | |
| fn=chat_with_tools, | |
| inputs=[gr.Textbox(label="Messages JSON"), gr.Textbox(label="Tools JSON")], | |
| outputs="json" | |
| ) | |
| if __name__ == "__main__": | |
| iface.launch() |