# Hugging Face Space app (page-scrape status banner "Spaces: Sleeping" removed).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import gradio as gr
import json
import os

# --- Change only these two lines if you update your base or adapter! ---
base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"
lora_adapter_path = "lingadevaruhp/thoshan_Flash"
# ----------------------------------------------------------------------

# Load the 4-bit base model, then stack the LoRA adapter on top of it.
# If anything fails, both globals are set to None so the UI can report the
# problem instead of crashing the whole Space at import time.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model = PeftModel.from_pretrained(base_model, lora_adapter_path)
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = None
    model = None
def load_dataset():
    """Load few-shot chat examples used as prompt context.

    Scans the working directory for known dataset files. A ``.jsonl`` file
    is parsed line by line (malformed lines are skipped silently); any other
    text file is sniffed for HTML (e.g. a failed download saved as markup)
    and otherwise triggers the built-in sample entries.

    Returns:
        list[dict]: entries shaped like ``{"input": ..., "output": ...}``
        (raw JSONL records may instead carry a ``"text"`` key). Falls back
        to a small built-in sample set when no usable file is found.
    """
    # Single source of truth for the canned examples — the original code
    # duplicated this list verbatim in two places.
    default_entries = [
        {"input": "Hello", "output": "Hi there! How are you doing today?"},
        {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
        {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"},
    ]

    dataset_files = ["2000-data-set.txt", "flirt_dataset.jsonl"]
    for dataset_file in dataset_files:
        if not os.path.exists(dataset_file):
            continue
        print(f"Found dataset file: {dataset_file}")
        if dataset_file.endswith('.jsonl'):
            dataset_entries = []
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue  # tolerate blank lines
                        try:
                            dataset_entries.append(json.loads(line))
                        except json.JSONDecodeError:
                            continue  # skip malformed records
                return dataset_entries
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading JSONL file {dataset_file}: {e}")
                continue
        else:
            try:
                with open(dataset_file, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                # Sniff for an HTML page masquerading as the dataset; skip
                # it rather than feeding markup into the prompt context.
                if content.startswith('<!DOCTYPE html>') or '<html>' in content:
                    print(f"Skipping HTML file: {dataset_file}")
                    continue
                # NOTE(review): the file's content is read but never parsed;
                # the original returned the canned samples here, which is
                # preserved. TODO: actually parse the .txt format.
                return default_entries
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading text file {dataset_file}: {e}")
                continue
    print("No valid dataset file found, using default responses")
    return default_entries
# Load the few-shot examples once at startup so every request reuses them.
dataset_content = load_dataset()
print(f"Loaded {len(dataset_content)} dataset entries")
def generate_response(prompt, max_new_tokens=100):
    """Generate a reply for *prompt* using the LoRA-adapted model.

    Args:
        prompt: The user's message.
        max_new_tokens: Upper bound on generated tokens. Gradio sliders can
            deliver this as a float, so it is coerced to int before use.

    Returns:
        The generated assistant text, or a human-readable error string when
        the model failed to load or generation raised an exception.
    """
    if model is None or tokenizer is None:
        return "Error: Model failed to load. Please check the logs and try restarting the space."
    try:
        # Prepend up to three dataset examples as in-context guidance.
        context = ""
        if dataset_content:
            context_text = ""
            for entry in dataset_content[:3]:
                if 'input' in entry and 'output' in entry:
                    context_text += f"User: {entry['input']}\nAssistant: {entry['output']}\n\n"
                elif 'text' in entry:
                    context_text += f"{entry['text']}\n\n"
            context = f"Dataset context:\n{context_text}\n" if context_text else ""
        # NOTE(review): this is Phi-style chat markup; Gemma bases use
        # <start_of_turn> tags — consider tokenizer.apply_chat_template.
        # Kept byte-for-byte to preserve current behavior.
        formatted_prompt = f"<|user|>\n{context}{prompt}<|end|>\n<|assistant|>\n"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # Fix: the Gradio slider may pass a float; generate needs int.
                max_new_tokens=int(max_new_tokens),
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=False,  # presumably to cap memory — TODO confirm
            )
        # Decode only the newly generated tokens, dropping the prompt echo.
        generated_text = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True,
        )
        return generated_text.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
# Gradio front-end: a message box plus a token-budget slider feeding
# generate_response, with a copyable output box for the reply.
_message_box = gr.Textbox(
    label="Your message",
    placeholder="Type your message here...",
    lines=4,
)
_token_slider = gr.Slider(minimum=10, maximum=200, value=100, label="Max New Tokens")
_response_box = gr.Textbox(
    label="AI Response",
    lines=10,
    show_copy_button=True,
)
iface = gr.Interface(
    fn=generate_response,
    inputs=[_message_box, _token_slider],
    outputs=_response_box,
    title="thoshan_Flash (Updated with JSONL Dataset)",
    description="Chat with AI powered by thoshan_Flash and the new flirt_dataset.jsonl dataset!",
)

# Launch only when executed directly (Spaces may import this module instead).
if __name__ == "__main__":
    iface.launch()