import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# --- 1. Check CUDA Availability and Set Device ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using device: {device} ({torch.cuda.get_device_name(0)})")
else:
    print("CUDA is not available. Falling back to CPU.")
    device = torch.device("cpu")

# float16 halves memory use on GPU but is poorly supported for CPU inference,
# so pick the dtype to match the device we ended up with.
dtype = torch.float16 if device.type == "cuda" else torch.float32
# --- 2. Load Tokenizer (with error handling) ---
MODEL_PATH = "sagar007/phi2_25k"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    # Phi-2's tokenizer ships without a pad token, so reuse the EOS token.
    tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    exit()
# --- 3. Load Base Model ---
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        torch_dtype=dtype,   # float16 on GPU, float32 on the CPU fallback
        device_map="auto",   # let accelerate place the weights automatically
        trust_remote_code=True
    )
except Exception as e:
    print(f"Error loading base model: {e}")
    exit()
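# If the float16 model does not fit in VRAM, a 4-bit quantized load is one
# alternative. A minimal sketch, assuming the bitsandbytes package is installed:
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#   base_model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/phi-2", quantization_config=bnb_config,
#       device_map="auto", trust_remote_code=True,
#   )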
# --- 4. Load PEFT Adapter ---
try:
    model = PeftModel.from_pretrained(base_model, MODEL_PATH)
except Exception as e:
    print(f"Error loading PEFT model: {e}")
    exit()

# device_map="auto" has already placed the weights, so do not call model.to(device)
# here: moving an accelerate-dispatched model raises a runtime error.
model.eval()
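# Optional: for inference-only use, the LoRA weights can be merged into the base
# model, removing the adapter overhead on every forward pass. merge_and_unload()
# is the standard PEFT helper for this and returns a plain transformers model:
#
#   model = model.merge_and_unload()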
# --- 5. Generation Function ---
def generate_response(instruction, max_length=512):
    prompt = f"Instruction: {instruction}\nResponse:"
    try:
        # Send the inputs to the device the model was dispatched to.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,  # note: this budget includes the prompt tokens
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text includes the prompt; keep only what follows "Response:".
        return response.split("Response:", 1)[-1].strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        return "Error during response generation."
# --- 6. Gradio Interface ---
def chatbot(message, history):
    # gr.ChatInterface supplies the conversation history, but each reply here
    # is generated from the latest message alone.
    return generate_response(message)

demo = gr.ChatInterface(
    chatbot,
    title="Fine-tuned Phi-2 Chatbot (GPU)",
    description="A chatbot backed by a fine-tuned version of the Phi-2 model, running on GPU.",
    theme="default",
    examples=[
        "Explain the concept of machine learning.",
        "Write a short story about a robot learning to paint.",
        "What are some effective ways to reduce stress?",
    ],
    cache_examples=False,  # set True to pre-compute the example responses at startup
)
if __name__ == "__main__":
    demo.launch()
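    # When running locally rather than on Spaces, standard Gradio options apply,
    # e.g. demo.queue() chained before launch to handle concurrent users, or
    # demo.launch(share=True) for a temporary public URL.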