from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import gradio as gr
import json
import os
# --- Change only these two lines if you update your base or adapter! ---
base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"
lora_adapter_path = "lingadevaruhp/thoshan_Flash"
# ----------------------------------------------------------------------
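# Note: a "bnb-4bit" prequantized base checkpoint presumably needs bitsandbytes
# and accelerate installed at runtime (an assumption based on the checkpoint
# name); if either is missing, loading fails and the except branch below runs.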
try:
    # Load the tokenizer and the 4-bit quantized base model, then attach the
    # LoRA adapter on top with PEFT.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager"
    )
    model = PeftModel.from_pretrained(base_model, lora_adapter_path)
except Exception as e:
    # Keep the app alive so the UI can surface the failure instead of crashing.
    print(f"Error loading model: {e}")
    tokenizer = None
    model = None
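# Standard PyTorch inference idiom: switch the adapter-wrapped model to eval
# mode so dropout is disabled (skipped when loading failed above).
if model is not None:
    model.eval()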
# Built-in fallback responses, used when no parseable dataset file is found.
SAMPLE_ENTRIES = [
    {"input": "Hello", "output": "Hi there! How are you doing today?"},
    {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
    {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"}
]

def load_dataset():
    """Return chat examples from the first readable dataset file, else the built-in samples."""
    dataset_files = ["2000-data-set.txt", "flirt_dataset.jsonl"]
    for dataset_file in dataset_files:
        if not os.path.exists(dataset_file):
            continue
        print(f"Found dataset file: {dataset_file}")
        if dataset_file.endswith('.jsonl'):
            dataset_entries = []
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            dataset_entries.append(json.loads(line.strip()))
                        except json.JSONDecodeError:
                            # Skip malformed lines rather than failing the whole file.
                            continue
                return dataset_entries
            except Exception as e:
                print(f"Error reading JSONL file {dataset_file}: {e}")
                continue
        else:
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                # Guard against an HTML error page saved in place of the dataset.
                if content.startswith('<!DOCTYPE html>') or '<html>' in content:
                    print(f"Skipping HTML file: {dataset_file}")
                    continue
                # Plain-text parsing is not implemented; fall back to the samples.
                return SAMPLE_ENTRIES
            except Exception as e:
                print(f"Error reading text file {dataset_file}: {e}")
                continue
    print("No valid dataset file found, using default responses")
    return SAMPLE_ENTRIES
dataset_content = load_dataset()
print(f"Loaded {len(dataset_content)} dataset entries")
def generate_response(prompt, max_new_tokens=100):
    if model is None or tokenizer is None:
        return "Error: Model failed to load. Please check the logs and try restarting the space."
    try:
        # Prepend up to three dataset examples as lightweight few-shot context.
        context = ""
        if dataset_content:
            context_text = ""
            for entry in dataset_content[:3]:
                if 'input' in entry and 'output' in entry:
                    context_text += f"User: {entry['input']}\nAssistant: {entry['output']}\n\n"
                elif 'text' in entry:
                    context_text += f"{entry['text']}\n\n"
            context = f"Dataset context:\n{context_text}\n" if context_text else ""
        # Use the tokenizer's own chat template (Gemma's <start_of_turn> markers)
        # rather than hand-written <|user|>/<|end|> tags from another model family.
        messages = [{"role": "user", "content": f"{context}{prompt}"}]
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The template already includes <bos>, so skip adding special tokens again.
        inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=False
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        return generated_text.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
# Gradio interface: multi-line input box, token-count slider, copyable output.
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            lines=4
        ),
        gr.Slider(minimum=10, maximum=200, value=100, label="Max New Tokens")
    ],
    outputs=gr.Textbox(
        label="AI Response",
        lines=10,
        show_copy_button=True
    ),
    title="thoshan_Flash (Updated with JSONL Dataset)",
    description="Chat with AI powered by thoshan_Flash and the new flirt_dataset.jsonl dataset!"
)
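# Optional: on Spaces, enabling Gradio's request queue can help under
# concurrent load.
# iface.queue()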
if __name__ == "__main__":
    iface.launch()