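"""Chat interface around a Hugging Face text-generation model.

Wraps AutoTokenizer/AutoModelForCausalLM in a transformers text-generation
pipeline, preferring safetensors weights and falling back to the standard
format (or to a degraded no-model mode) when loading fails.
"""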
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from typing import Dict
from rich.console import Console

console = Console()

class HuggingFaceChat:
    """Interface for chatting with Hugging Face models."""

    def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
        """
        Initialize the chat interface.

        Args:
            model_name: The Hugging Face model to use
        """
        self.model_name = model_name
        self.device = 0 if torch.cuda.is_available() else -1  # Use GPU if available

        # Try loading the model with safetensors first
        try:
            console.print(f"[blue]Loading model: {model_name}[/blue]")
            # safetensors applies to the model weights; the tokenizer loads normally
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
            self.chatbot = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device
            )
            console.print("[green]✓ Model loaded successfully[/green]")
        except Exception as e:
            console.print(f"[red]Error loading model with safetensors: {e}[/red]")
            try:
                # Fallback to regular loading
                console.print("[yellow]Trying regular loading...[/yellow]")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForCausalLM.from_pretrained(model_name)
                self.chatbot = pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    device=self.device
                )
                console.print("[green]✓ Model loaded successfully[/green]")
            except Exception as e2:
                console.print(f"[red]Error loading model: {e2}[/red]")
                console.print("[yellow]Falling back to simple text generation...[/yellow]")
                self.chatbot = None

    def generate_response(self, prompt: str, max_length: int = 1000) -> str:
        """
        Generate a response to a prompt.

        Args:
            prompt: The input prompt
            max_length: Maximum total length in tokens (prompt plus generation)

        Returns:
            The generated response
        """
        if not self.chatbot:
            return "I'm sorry, but I couldn't load the AI model. This might be due to:\n1. Model loading issues\n2. Internet connection problems\n3. Server maintenance\n\nPlease try again in a few minutes."

        try:
            # Sampling parameters chosen for varied but coherent output
            response = self.chatbot(
                prompt,
                max_length=max_length,
                do_sample=True,
                temperature=0.8,         # soften the distribution for more varied output
                top_p=0.95,              # nucleus sampling over the top 95% of probability mass
                top_k=50,                # sample only from the 50 most likely tokens
                repetition_penalty=1.1,  # mild penalty to discourage repeated tokens
                no_repeat_ngram_size=3,  # never repeat any 3-token sequence
                pad_token_id=self.tokenizer.eos_token_id,
                truncation=True   # Enable truncation to prevent errors
            )

            # Extract the generated text
            generated_text = response[0]['generated_text']

            # Remove the prompt from the response if it's included
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            # Clean up the response
            # Remove any incomplete sentences or hanging punctuation
            generated_text = self._clean_response(generated_text)

            # If response is empty or too short, provide a helpful message
            if not generated_text or len(generated_text.strip()) < 5:
                return "I'm processing your message. Could you please try again or rephrase your question?"

            return generated_text

        except Exception as e:
            console.print(f"[red]Error generating response: {e}[/red]")
            return f"I'm experiencing technical difficulties. Error: {str(e)[:100]}..."

    def _clean_response(self, text: str) -> str:
        """Clean up the generated response."""
        # Already ends at a sentence boundary; nothing to trim
        if text.endswith(('.', '!', '?')):
            return text

        # Split on sentence boundaries to find the last complete sentence
        sentences = re.split(r'(?<=[.!?])\s+', text)
        if len(sentences) > 1:
            # Remove the last incomplete sentence
            text = ' '.join(sentences[:-1])

        return text.strip()

    def check_model_availability(self) -> bool:
        """Check if the model is available."""
        return self.chatbot is not None

    def get_model_info(self) -> Dict:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": "GPU" if self.device == 0 else "CPU",
            "available": self.chatbot is not None
        }
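

# Minimal usage sketch (assumption: run as a script with network access so the
# default model can be downloaded on first use). Exercises the public API:
# construction, availability check, model info, and response generation.
if __name__ == "__main__":
    chat = HuggingFaceChat()
    if chat.check_model_availability():
        console.print(chat.get_model_info())
        console.print(chat.generate_response("Hello! How are you today?"))
    else:
        console.print("[red]Model could not be loaded; see messages above.[/red]")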