# Hugging Face Space app by lingadevaruhp.
# Revision f28fe9e (verified): fix Gradio compatibility by removing
# unsupported show_clear_button arguments.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import gradio as gr
import json
import os
# --- Change only these two lines if you update your base or adapter! ---
base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"  # 4-bit quantized Gemma 2 9B instruct base
lora_adapter_path = "lingadevaruhp/thoshan_Flash"   # LoRA adapter applied on top via PEFT
# ----------------------------------------------------------------------
# Load the tokenizer and base model, then attach the LoRA adapter.
# On any failure (missing weights, OOM, network error, ...) fall back to
# None so the Gradio UI can still start and surface the error to the user
# instead of crashing the whole Space at import time.
try:
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,        # bf16 activations over the 4-bit weights
        device_map="auto",                 # let accelerate place layers across available devices
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager"        # NOTE(review): presumably avoids a flash-attn dependency — confirm
    )
    model = PeftModel.from_pretrained(base_model, lora_adapter_path)
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = None   # sentinel values checked by generate_response()
    model = None
def load_dataset():
    """Load chat examples from a local dataset file.

    Candidate files are tried in order:

    * ``*.jsonl`` files are parsed line by line; blank and malformed lines
      are skipped silently.
    * Plain-text files are only sanity-checked (HTML pages are rejected) and
      then replaced by the built-in sample entries — the text content itself
      is NOT parsed.  NOTE(review): this looks like placeholder behavior.

    Returns:
        list[dict]: entries shaped like ``{"input": ..., "output": ...}``
        (or ``{"text": ...}`` from JSONL); falls back to the built-in
        defaults when no usable file is found.
    """
    # Single source of truth for the fallback examples (previously this
    # literal was duplicated in two places).
    default_entries = [
        {"input": "Hello", "output": "Hi there! How are you doing today?"},
        {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
        {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"},
    ]
    dataset_files = ["2000-data-set.txt", "flirt_dataset.jsonl"]
    for dataset_file in dataset_files:
        if not os.path.exists(dataset_file):
            continue
        print(f"Found dataset file: {dataset_file}")
        if dataset_file.endswith('.jsonl'):
            dataset_entries = []
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue  # blank line — nothing to parse
                        try:
                            dataset_entries.append(json.loads(line))
                        except json.JSONDecodeError:
                            continue  # tolerate malformed lines
                return dataset_entries
            except Exception as e:
                print(f"Error reading JSONL file {dataset_file}: {e}")
                continue
        else:
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                # Guard against an HTML error page saved in place of data.
                if content.startswith('<!DOCTYPE html>') or '<html>' in content:
                    print(f"Skipping HTML file: {dataset_file}")
                    continue
                return default_entries
            except Exception as e:
                print(f"Error reading text file {dataset_file}: {e}")
                continue
    print("No valid dataset file found, using default responses")
    return default_entries
# Load the in-context examples once at startup; generate_response() reuses
# the first few entries on every request.
dataset_content = load_dataset()
print(f"Loaded {len(dataset_content)} dataset entries")
def generate_response(prompt, max_new_tokens=100):
    """Generate a reply to *prompt* with the LoRA-adapted model.

    Up to three dataset entries are prepended to the prompt as in-context
    examples.  Returns the decoded continuation, or a human-readable error
    string when the model failed to load or generation raised.
    """
    # Guard clause: the module-level loader leaves these as None on failure.
    if model is None or tokenizer is None:
        return "Error: Model failed to load. Please check the logs and try restarting the space."
    try:
        # Assemble the few-shot context from the first few dataset entries.
        pieces = []
        if dataset_content:
            for example in dataset_content[:3]:
                if 'input' in example and 'output' in example:
                    pieces.append(f"User: {example['input']}\nAssistant: {example['output']}\n\n")
                elif 'text' in example:
                    pieces.append(f"{example['text']}\n\n")
        context_text = "".join(pieces)
        context = f"Dataset context:\n{context_text}\n" if context_text else ""
        # NOTE(review): these chat markers look Phi-style; confirm they match
        # the base model's expected template.
        formatted_prompt = f"<|user|>\n{context}{prompt}<|end|>\n<|assistant|>\n"
        encoded = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=False,
            )
        # Drop the prompt tokens; decode only the model's continuation.
        prompt_len = encoded['input_ids'].shape[1]
        reply = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
        return reply.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
# Updated Gradio interface with enhanced textbox features.
# Two inputs (message text + token-budget slider) map positionally onto
# generate_response(prompt, max_new_tokens).
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            lines=4
        ),
        # Slider value is forwarded as max_new_tokens (default 100).
        gr.Slider(minimum=10, maximum=200, value=100, label="Max New Tokens")
    ],
    outputs=gr.Textbox(
        label="AI Response",
        lines=10,
        show_copy_button=True
    ),
    title="thoshan_Flash (Updated with JSONL Dataset)",
    description="Chat with AI powered by thoshan_Flash and the new flirt_dataset.jsonl dataset!"
)
# Launch the Gradio app only when run directly (not when imported).
if __name__ == "__main__":
    iface.launch()