from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import gradio as gr
import json
import os

# --- Change only these two lines if you update your base or adapter! ---
base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"
lora_adapter_path = "lingadevaruhp/thoshan_Flash"
# ----------------------------------------------------------------------

# Load the 4-bit base model and attach the LoRA adapter on top of it.
# If anything fails, keep the app alive and surface the error at request time.
try:
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager"
    )
    model = PeftModel.from_pretrained(base_model, lora_adapter_path)
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = None
    model = None


def load_dataset():
    """Look for a local dataset file and return a list of example exchanges."""
    dataset_files = ["2000-data-set.txt", "flirt_dataset.jsonl"]
    for dataset_file in dataset_files:
        if os.path.exists(dataset_file):
            print(f"Found dataset file: {dataset_file}")
            if dataset_file.endswith('.jsonl'):
                # One JSON object per line; silently skip lines that fail to parse.
                dataset_entries = []
                try:
                    with open(dataset_file, 'r', encoding='utf-8') as f:
                        for line in f:
                            try:
                                entry = json.loads(line.strip())
                                dataset_entries.append(entry)
                            except json.JSONDecodeError:
                                continue
                    return dataset_entries
                except Exception as e:
                    print(f"Error reading JSONL file {dataset_file}: {e}")
                    continue
            else:
                # The .txt file itself is not parsed: rule out an HTML dump,
                # then fall back to a few built-in sample exchanges.
                try:
                    with open(dataset_file, 'r', encoding='utf-8') as f:
                        content = f.read().strip()
                    if content.startswith('<!DOCTYPE') or '<html' in content:
                        print(f"Skipping HTML file: {dataset_file}")
                        continue
                    sample_entries = [
                        {"input": "Hello", "output": "Hi there! How are you doing today?"},
                        {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
                        {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"}
                    ]
                    return sample_entries
                except Exception as e:
                    print(f"Error reading text file {dataset_file}: {e}")
                    continue
    print("No valid dataset file found, using default responses")
    return [
        {"input": "Hello", "output": "Hi there! How are you doing today?"},
        {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
        {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"}
    ]


dataset_content = load_dataset()
print(f"Loaded {len(dataset_content)} dataset entries")


def generate_response(prompt, max_new_tokens=100):
    if model is None or tokenizer is None:
        return "Error: Model failed to load. Please check the logs and try restarting the space."
    try:
        # Prepend up to three dataset entries as lightweight in-prompt context.
        context = ""
        if dataset_content:
            context_entries = dataset_content[:3]
            context_text = ""
            for entry in context_entries:
                if 'input' in entry and 'output' in entry:
                    context_text += f"User: {entry['input']}\nAssistant: {entry['output']}\n\n"
                elif 'text' in entry:
                    context_text += f"{entry['text']}\n\n"
            context = f"Dataset context:\n{context_text}\n" if context_text else ""

        # NOTE: this hand-rolled template uses Phi-style <|user|>/<|assistant|> tags;
        # Gemma-2 chat checkpoints normally expect <start_of_turn>/<end_of_turn>
        # markers (tokenizer.apply_chat_template would produce them). Kept as-is
        # in case the adapter was trained on this exact format.
        formatted_prompt = f"<|user|>\n{context}{prompt}<|end|>\n<|assistant|>\n"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=False
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        return generated_text.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"


# Updated Gradio interface with enhanced textbox features
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            lines=4
        ),
        gr.Slider(minimum=10, maximum=200, value=100, label="Max New Tokens")
    ],
    outputs=gr.Textbox(
        label="AI Response",
        lines=10,
        show_copy_button=True
    ),
    title="thoshan_Flash (Updated with JSONL Dataset)",
    description="Chat with AI powered by thoshan_Flash and the new flirt_dataset.jsonl dataset!"
)

if __name__ == "__main__":
    iface.launch()
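
# Runtime dependencies (a sketch, not an official manifest): besides the packages
# imported above, this script typically also needs bitsandbytes for the 4-bit
# "bnb-4bit" base checkpoint and accelerate for device_map="auto", e.g.
#   pip install torch transformers peft gradio bitsandbytes accelerate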