# Hugging Face Space (status when this copy was captured: Sleeping)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Conversational model and its tokenizer, loaded once at import time with
# `from_pretrained` and shared by every call to `reply`.
model_name = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
# Chat function

# Total token budget for prompt + generated reply (the value the original
# `max_length` used). DialoGPT is GPT-2 based; its position limit is assumed to
# be 1024 for the medium checkpoint (see model.config.n_positions) — confirm.
_MAX_LENGTH = 1000
# Tokens always kept free for the reply; older turns are dropped to honour it.
_REPLY_BUDGET = 200


def _content_to_text(content):
    """Return the plain text held by a Gradio message ``content`` value.

    Accepts a plain string, a text content block (a dict whose ``"text"`` key
    is a string), or a list mixing those. Returns ``None`` when the content
    carries no text (files, components, ``None``).
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else None
    if isinstance(content, list):
        parts = [_content_to_text(part) for part in content]
        parts = [part for part in parts if part is not None]
        return " ".join(parts) if parts else None
    return None


def _history_texts(history):
    """Yield the text of every earlier turn, oldest first.

    Supports Gradio's "messages" history (dicts, or objects exposing a
    ``content`` attribute) and "tuples" history (``[user, bot]`` pairs).
    Turns that carry no text are skipped.
    """
    for item in history or ():
        if isinstance(item, dict):
            contents = (item.get("content"),)
        elif isinstance(item, (list, tuple)):
            contents = item
        else:
            contents = (getattr(item, "content", None),)
        for content in contents:
            text = _content_to_text(content)
            if text is not None:
                yield text


def _fit_context(turn_ids, budget):
    """Flatten the newest per-turn token-id lists that fit in ``budget`` tokens.

    Whole turns are dropped from the oldest end. If even the last turn (the
    current message) exceeds ``budget``, only its final ``budget`` tokens
    are kept.
    """
    kept = []
    used = 0
    for ids in reversed(turn_ids):
        if used + len(ids) > budget:
            break
        kept.append(ids)
        used += len(ids)
    if not kept:
        return list(turn_ids[-1][-budget:])
    return [token for ids in reversed(kept) for token in ids]


def reply(message, history):
    """Generate DialoGPT's answer to ``message`` given the conversation so far.

    Args:
        message: The new user message (a string from Gradio's ChatInterface).
        history: Earlier turns, in Gradio's "messages" or "tuples" format.

    Returns:
        The model's reply text; "Please enter a message." when ``message`` is
        empty, blank or ``None``; or a fallback greeting when the model
        produces no visible text.
    """
    if not message or not message.strip():
        return "Please enter a message."

    # DialoGPT separates conversation turns with the EOS token.
    turns = [*_history_texts(history), message]
    turn_ids = [tokenizer.encode(text + tokenizer.eos_token) for text in turns]
    # Cap the prompt: otherwise it grows every turn until max_length leaves no
    # room for a reply and, past the model's position limit, generation fails.
    context = _fit_context(turn_ids, _MAX_LENGTH - _REPLY_BUDGET)

    input_ids = torch.tensor([context], dtype=torch.long)
    # A single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=_MAX_LENGTH,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        repetition_penalty=1.2,
    )

    # The output contains prompt + continuation; decode only the new tokens.
    response = tokenizer.decode(
        output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True
    )
    if not response.strip():
        response = "I'm here! How can I help you?"
    return response
# Chat UI: ChatInterface calls `reply(message, history)` for every new user
# message and displays the returned string as the bot's answer.
demo = gr.ChatInterface(
    reply,
    title="💬 Smart Dialogue System",
    description="Full conversation chatbot using DialoGPT",
)
# Launch: bind to 0.0.0.0 on port 7860 with server-side rendering disabled.
demo.launch(server_port=7860, server_name="0.0.0.0", ssr_mode=False)