File size: 5,411 Bytes
2e861b0
7d955d5
3ee9356
2e861b0
06546fe
 
3ee9356
7d955d5
 
 
 
2e861b0
811c844
7d955d5
 
 
811c844
 
 
8625cb1
7d955d5
811c844
7d955d5
811c844
 
 
 
3ee9356
06546fe
811c844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc50fe8
811c844
 
97aef11
 
 
811c844
 
 
 
 
 
 
 
fc50fe8
811c844
 
 
 
 
 
06546fe
 
811c844
06546fe
 
811c844
 
 
 
 
7d955d5
811c844
 
 
 
 
 
 
 
7d955d5
811c844
 
 
 
 
 
 
8625cb1
7d955d5
811c844
 
 
 
 
3ee9356
97aef11
2e861b0
 
8d219ad
97aef11
 
 
f28fe9e
97aef11
06546fe
8d219ad
97aef11
 
 
f28fe9e
97aef11
2956b29
 
2e861b0
 
 
97aef11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import gradio as gr
import json
import os

# --- Change only these two lines if you update your base or adapter! ---
base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"
lora_adapter_path = "lingadevaruhp/thoshan_Flash"
# ----------------------------------------------------------------------

# Load the tokenizer, the 4-bit-quantized base model, and the LoRA adapter.
# Any failure (network, OOM, bad adapter path) is caught so the Gradio app
# can still start and report the error instead of crashing at import time;
# generate_response() checks for the None sentinels before using the model.
try:
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",              # let accelerate place layers on available devices
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager"     # Gemma-2 is commonly run with eager attention
    )
    # Attach the fine-tuned LoRA weights on top of the frozen base model.
    model = PeftModel.from_pretrained(base_model, lora_adapter_path)
except Exception as e:
    print(f"Error loading model: {e}")
    # Sentinels: downstream code treats None as "model unavailable".
    tokenizer = None
    model = None

def load_dataset():
    """Load few-shot chat examples from a local dataset file.

    Searches for known dataset files in the working directory. A ``.jsonl``
    file is parsed line by line (malformed lines are skipped). A plain-text
    file is only sanity-checked (HTML downloads are rejected) and then the
    built-in sample entries are returned — the text format itself is never
    parsed (TODO: implement real parsing for the .txt format if needed).

    Returns:
        list[dict]: dataset entries, each typically with ``input``/``output``
        keys; falls back to three built-in default entries when no usable
        file is found.
    """
    # Single source of truth for the fallback entries (the original code
    # duplicated this list in two places, risking drift between them).
    default_entries = [
        {"input": "Hello", "output": "Hi there! How are you doing today?"},
        {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
        {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"}
    ]
    dataset_files = ["2000-data-set.txt", "flirt_dataset.jsonl"]
    for dataset_file in dataset_files:
        if not os.path.exists(dataset_file):
            continue
        print(f"Found dataset file: {dataset_file}")
        if dataset_file.endswith('.jsonl'):
            dataset_entries = []
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            dataset_entries.append(json.loads(line.strip()))
                        except json.JSONDecodeError:
                            # Tolerate occasional malformed lines rather
                            # than discarding the whole file.
                            continue
                return dataset_entries
            except Exception as e:
                print(f"Error reading JSONL file {dataset_file}: {e}")
                continue
        else:
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                # A failed download often saves an HTML error page under
                # the dataset filename; detect and skip that case.
                if content.startswith('<!DOCTYPE html>') or '<html>' in content:
                    print(f"Skipping HTML file: {dataset_file}")
                    continue
                # NOTE(review): the .txt content is read but never parsed;
                # the defaults are returned instead (preserved behavior).
                return default_entries
            except Exception as e:
                print(f"Error reading text file {dataset_file}: {e}")
                continue
    print("No valid dataset file found, using default responses")
    return default_entries

# Load the few-shot context examples once at import time; generate_response()
# reads this module-level list on every call.
dataset_content = load_dataset()
print(f"Loaded {len(dataset_content)} dataset entries")

def generate_response(prompt, max_new_tokens=100):
    """Generate a model reply for *prompt*, prefixed with dataset context.

    Args:
        prompt (str): the user's message.
        max_new_tokens (int): generation budget for the reply.

    Returns:
        str: the generated reply, or a human-readable error string when the
        model failed to load or generation raised.
    """
    if model is None or tokenizer is None:
        return "Error: Model failed to load. Please check the logs and try restarting the space."
    try:
        # Build an optional few-shot context block from the first three
        # dataset entries (input/output pairs or raw "text" entries).
        context = ""
        if dataset_content:
            pieces = []
            for entry in dataset_content[:3]:
                if 'input' in entry and 'output' in entry:
                    pieces.append(f"User: {entry['input']}\nAssistant: {entry['output']}\n\n")
                elif 'text' in entry:
                    pieces.append(f"{entry['text']}\n\n")
            context_text = "".join(pieces)
            context = f"Dataset context:\n{context_text}\n" if context_text else ""
        # BUG FIX: the original hard-coded "<|user|>...<|end|>\n<|assistant|>"
        # markers, which are a Phi-style template; Gemma-2 uses
        # "<start_of_turn>"-based turns and would see those markers as plain
        # text. Delegate to the tokenizer's own chat template so the prompt
        # always matches the loaded model.
        messages = [{"role": "user", "content": f"{context}{prompt}"}]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=False  # NOTE(review): disabling the KV cache slows generation — confirm this is intentional
            )
        # Decode only the newly generated tokens (skip the prompt tokens).
        generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        return generated_text.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"

# Updated Gradio interface with enhanced textbox features
# Declarative UI: one text input + a token-budget slider feed
# generate_response(); the output textbox shows the reply with a copy button.
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(
            label="Your message", 
            placeholder="Type your message here...",
            lines=4
        ),
        # Maps to generate_response's max_new_tokens parameter.
        gr.Slider(minimum=10, maximum=200, value=100, label="Max New Tokens")
    ],
    outputs=gr.Textbox(
        label="AI Response",
        lines=10,
        show_copy_button=True
    ),
    title="thoshan_Flash (Updated with JSONL Dataset)",
    description="Chat with AI powered by thoshan_Flash and the new flirt_dataset.jsonl dataset!"
)

# Launch the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()