# NOTE: this file was recovered from a Hugging Face Space file viewer;
# the viewer's metadata (Space status, file size, revision hashes, and
# line-number gutter) has been stripped from the source below.
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from typing import List, Dict, Optional
from rich.console import Console
# Module-wide rich console: all load/progress/error messages go through it.
console = Console()
class HuggingFaceChat:
"""Interface for chatting with Hugging Face models."""
def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
"""
Initialize the chat interface.
Args:
model_name: The Hugging Face model to use
"""
self.model_name = model_name
self.device = 0 if torch.cuda.is_available() else -1 # Use GPU if available
# Try loading the model with safetensors first
try:
console.print(f"[blue]Loading model: {model_name}[/blue]")
# Try to load with safetensors format first
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
self.model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
self.chatbot = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
device=self.device
)
console.print("[green]✓ Model loaded successfully[/green]")
except Exception as e:
console.print(f"[red]Error loading model with safetensors: {e}[/red]")
try:
# Fallback to regular loading
console.print("[yellow]Trying regular loading...[/yellow]")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
self.chatbot = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
device=self.device
)
console.print("[green]✓ Model loaded successfully[/green]")
except Exception as e2:
console.print(f"[red]Error loading model: {e2}[/red]")
console.print("[yellow]Falling back to simple text generation...[/yellow]")
self.chatbot = None
def generate_response(self, prompt: str, max_length: int = 1000) -> str:
"""
Generate a response to a prompt.
Args:
prompt: The input prompt
max_length: Maximum length of the generated response
Returns:
The generated response
"""
if not self.chatbot:
return "I'm sorry, but I couldn't load the AI model. This might be due to:\n1. Model loading issues\n2. Internet connection problems\n3. Server maintenance\n\nPlease try again in a few minutes."
try:
# Generate response with improved parameters for better quality
response = self.chatbot(
prompt,
max_length=max_length,
do_sample=True,
temperature=0.8, # Higher for more creativity
top_p=0.95, # Higher for more diverse responses
top_k=50, # Limit to top 50 tokens
repetition_penalty=1.1, # Lower penalty for more natural flow
no_repeat_ngram_size=3, # Avoid repeating 3-word phrases
pad_token_id=self.tokenizer.eos_token_id,
truncation=True # Enable truncation to prevent errors
)
# Extract the generated text
generated_text = response[0]['generated_text']
# Remove the prompt from the response if it's included
if generated_text.startswith(prompt):
generated_text = generated_text[len(prompt):].strip()
# Clean up the response
# Remove any incomplete sentences or hanging punctuation
generated_text = self._clean_response(generated_text)
# If response is empty or too short, provide a helpful message
if not generated_text or len(generated_text.strip()) < 5:
return "I'm processing your message. Could you please try again or rephrase your question?"
return generated_text
except Exception as e:
console.print(f"[red]Error generating response: {e}[/red]")
return f"I'm experiencing technical difficulties. Error: {str(e)[:100]}..."
def _clean_response(self, text: str) -> str:
"""Clean up the generated response."""
# Remove any trailing incomplete sentences
if text.endswith(('.', '!', '?')):
return text
# Find the last complete sentence
import re
sentences = re.split(r'(?<=[.!?])\s+', text)
if len(sentences) > 1:
# Remove the last incomplete sentence
text = ' '.join(sentences[:-1])
return text.strip()
def check_model_availability(self) -> bool:
"""Check if the model is available."""
return self.chatbot is not None
def get_model_info(self) -> Dict:
"""Get information about the loaded model."""
return {
"model_name": self.model_name,
"device": "GPU" if self.device == 0 else "CPU",
"available": self.chatbot is not None
} |