# app.py — Load the base model (4-bit quantized via bitsandbytes when CUDA is
# available), apply a LoRA adapter (local checkpoint preferred, Hugging Face Hub
# repo as fallback), take the tokenizer from the base model, and serve
# generation through a Gradio UI with sampling sliders.
import os

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel

# ---- USER CONFIG ----
# If ADAPTER_LOCAL_DIR exists, that local checkpoint (e.g. checkpoint-9000) will be used.
ADAPTER_LOCAL_DIR = os.environ.get("ADAPTER_LOCAL_DIR", "qwen_lora_sft_output/checkpoint-9000")
HF_ADAPTER_REPO = "GilbertAkham/gilbert-qwen-multitask-lora"  # fallback adapter repo id
BASE_MODEL = "Qwen/Qwen1.5-1.8B-Chat"
# ---------------------

# Approximate cap on prompt length in tokens (template + user input). Only the
# user input is truncated, so the "### Output:" marker is never cut off.
MAX_PROMPT_TOKENS = 1024

# Instruction placed in the "### Task:" header for each UI task type.
TASK_PROMPTS = {
    "email": "Draft an email reply",
    "story": "Continue the story",
    "tech": "Answer the technical question",
    "summary": "Summarize the content",
    "chat": "Provide a helpful chat response",
}
DEFAULT_TASK_PROMPT = "Provide a reply"
OUTPUT_MARKER = "### Output:"


class MultitaskInference:
    """Base causal LM + LoRA adapter that answers task-tagged prompts.

    The tokenizer, base model and adapter are all loaded eagerly in ``__init__``.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        """Load tokenizer, base model and adapter, then switch to eval mode.

        Raises:
            RuntimeError: if the tokenizer, base model or adapter cannot be loaded.
        """
        self._load_tokenizer()
        self._load_base_model()
        self._load_adapter()
        self.model.eval()
        print("Model + adapter loaded. Device:", self.device)

    def _load_tokenizer(self):
        """Load the tokenizer from the base model, falling back to adapter sources."""
        candidates = [BASE_MODEL]
        if os.path.isdir(ADAPTER_LOCAL_DIR):
            candidates.append(ADAPTER_LOCAL_DIR)
        candidates.append(HF_ADAPTER_REPO)

        last_error = None
        for source in candidates:
            print("Loading tokenizer from:", source)
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    source, use_fast=False, trust_remote_code=True
                )
                break
            except Exception as err:  # missing files, network, bad config, ...
                print(f"Failed to load tokenizer from {source}: {err}")
                last_error = err
        else:
            raise RuntimeError("Cannot load tokenizer from base or adapter repos.") from last_error

        # generate() needs a pad token; reuse EOS when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_base_model(self):
        """Load BASE_MODEL into ``self.base`` (4-bit NF4 on CUDA, float32 on CPU)."""
        compute_dtype = torch.float16 if self.device == "cuda" else torch.float32

        bnb_config = None
        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            print("Using 4-bit quantized loader (bitsandbytes) for the base model.")

        print("Loading base model:", BASE_MODEL)
        try:
            self.base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map="auto" if self.device == "cuda" else None,
                quantization_config=bnb_config,
                torch_dtype=compute_dtype,
                trust_remote_code=True,
            )
        except Exception as err:
            raise RuntimeError(f"Failed to load base model {BASE_MODEL}: {err}") from err

    def _load_adapter(self):
        """Wrap ``self.base`` with the LoRA adapter (local dir preferred) as ``self.model``."""
        if os.path.isdir(ADAPTER_LOCAL_DIR):
            adapter_source = ADAPTER_LOCAL_DIR
            print("Found local adapter checkpoint:", ADAPTER_LOCAL_DIR)
        else:
            adapter_source = HF_ADAPTER_REPO
            print("Local adapter not found — will try to load adapter from HF repo:", HF_ADAPTER_REPO)

        print(f"Loading LoRA adapter from: {adapter_source}")
        try:
            # Accepts either a local path or a Hub repo id.
            self.model = PeftModel.from_pretrained(self.base, adapter_source)
        except Exception as err:
            raise RuntimeError(f"Failed to load LoRA adapter from {adapter_source}: {err}") from err

        # On CUDA the weights were already placed by device_map="auto"; calling
        # .to() on a bitsandbytes-quantized model is unsupported in many
        # transformers versions, so only move the model explicitly on CPU.
        if self.device != "cuda":
            self.model.to(self.device)

    def _build_prompt(self, task_type, input_text):
        """Return the instruction prompt, truncating only the user input if needed."""
        header = f"### Task: {TASK_PROMPTS.get(task_type, DEFAULT_TASK_PROMPT)}\n\n### Input:\n"
        footer = f"\n\n{OUTPUT_MARKER}\n"

        template_len = len(self.tokenizer(header + footer, add_special_tokens=False)["input_ids"])
        budget = max(MAX_PROMPT_TOKENS - template_len, 0)
        input_ids = self.tokenizer(input_text, add_special_tokens=False)["input_ids"]
        if len(input_ids) > budget:
            input_text = self.tokenizer.decode(input_ids[:budget], skip_special_tokens=True)
        return header + input_text + footer

    def generate_response(self, task_type: str, input_text: str, max_new_tokens: int = 200,
                          temperature: float = 0.7, top_p: float = 0.9):
        """Generate a reply for ``input_text`` under the given task type.

        Args:
            task_type: key of TASK_PROMPTS; unknown keys use a generic instruction.
            input_text: user-provided text placed under "### Input:".
            max_new_tokens: generation length cap.
            temperature: sampling temperature; values <= 0 select greedy decoding.
            top_p: nucleus-sampling threshold (ignored for greedy decoding).

        Returns:
            The generated text, or a "❌ Generation error: ..." string if
            generation raised.
        """
        prompt = self._build_prompt(task_type, input_text)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {key: tensor.to(self.model.device) for key, tensor in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]

        do_sample = temperature > 0
        sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}

        try:
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    **sampling_kwargs,
                )
            # Decode only the newly generated tokens, never the echoed prompt.
            text = self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        except Exception as e:
            return f"❌ Generation error: {e}"

        # If the model itself emitted another output marker, keep the last answer.
        if OUTPUT_MARKER in text:
            text = text.split(OUTPUT_MARKER)[-1]
        return text.strip()


# Create engine (this will load model on startup)
engine = MultitaskInference()


# Gradio UI
def process_request(task_type, user_input, max_tokens, temperature, top_p):
    """Gradio callback: validate the input, then delegate to the engine."""
    if not user_input or not user_input.strip():
        return "⚠️ Please enter some input text."
    return engine.generate_response(
        task_type,
        user_input,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
    )


examples = [
    ["chat", "Hey — my VPN won't connect. Any suggestions?"],
    ["email", "Subject: Project update\nBody: Please share the status of Task A."],
    ["story", "The lighthouse blinked twice and the fog rolled in..."],
    ["tech", "What is the difference between model.eval() and model.train() in PyTorch?"],
    ["summary", "AI systems are transforming industries through automation and data insights..."],
]

with gr.Blocks(title="Gilbert Multitask AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"## 🚀 Gilbert Multitask AI\n\n**Base model:** {BASE_MODEL}\n\nLoRA adapter: local `{ADAPTER_LOCAL_DIR}` if present, otherwise `{HF_ADAPTER_REPO}`."
    )
    with gr.Row():
        with gr.Column(scale=1):
            task_type = gr.Dropdown(choices=["chat", "email", "story", "tech", "summary"], value="chat", label="Task")
            max_tokens = gr.Slider(50, 1024, value=200, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            input_box = gr.Textbox(lines=8, label="Input")
            output_box = gr.Textbox(lines=10, label="Generated Response", show_copy_button=True)
            btn = gr.Button("Generate")

    # Examples must target the visible input box (defined above) so clicking one
    # actually fills the text that gets submitted.
    gr.Examples(examples=examples, inputs=[task_type, input_box])

    btn.click(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)
    input_box.submit(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)