GilbertAkham committed on
Commit
96323b7
·
verified ·
1 Parent(s): 2db099b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — Fixed: load quantized base + local LoRA checkpoint (preferred),
2
+ # tokenizer from base, device-safe generation, Gradio UI with sliders.
3
+ import os
4
+ import gradio as gr
5
+ import torch
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ AutoModelForCausalLM,
9
+ BitsAndBytesConfig,
10
+ )
11
+ from peft import PeftModel
12
+
13
# ---- USER CONFIG ----
# Adapter resolution order: if ADAPTER_LOCAL_DIR exists on disk, that local
# checkpoint (e.g. checkpoint-9000) is used; otherwise the hub repo below.
# Override the local path via the ADAPTER_LOCAL_DIR environment variable.
ADAPTER_LOCAL_DIR = os.environ.get("ADAPTER_LOCAL_DIR", "qwen_lora_sft_output/checkpoint-9000")
HF_ADAPTER_REPO = "GilbertAkham/gilbert-qwen-multitask-lora"  # fallback adapter repo id
BASE_MODEL = "Qwen/Qwen1.5-1.8B-Chat"  # source of both tokenizer and base weights
# ---------------------
19
+
20
class MultitaskInference:
    """Load the Qwen base model (4-bit quantized on CUDA) plus a LoRA adapter
    and expose a single prompt-templated generation entry point.

    The adapter comes from ADAPTER_LOCAL_DIR when that directory exists,
    otherwise from HF_ADAPTER_REPO. All heavy loading happens in __init__,
    so construct this class exactly once (done at module import below).
    """

    def __init__(self):
        # CUDA enables the 4-bit bitsandbytes path below; CPU falls back to fp32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None      # PeftModel (base + LoRA) after loading
        self.tokenizer = None  # tokenizer shared by base model and adapter
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        """Load tokenizer, quantized base model, and LoRA adapter, in that order.

        Raises RuntimeError when the tokenizer, base model, or adapter cannot
        be loaded from any configured source.
        """
        compute_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Tokenizer from the base model first (keeps vocab/IDs in sync with
        # the base weights); the adapter repo is only a fallback.
        print("Loading tokenizer from base model:", BASE_MODEL)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False, trust_remote_code=True)
        except Exception as e:
            print("Failed to load tokenizer from base model:", e)
            print("Trying tokenizer from local adapter or HF adapter repo as fallback...")
            # Fallback attempt: adapter repo may ship its own tokenizer files.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(HF_ADAPTER_REPO, use_fast=False, trust_remote_code=True)
            except Exception as e2:
                raise RuntimeError("Cannot load tokenizer from base or adapter repos.") from e2

        # Some chat models ship without a pad token; generation needs one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Prepare a bitsandbytes 4-bit (NF4, double-quant) config — CUDA only.
        bnb_config = None
        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            print("Using 4-bit quantized loader (bitsandbytes) for the base model.")

        # Load the base model (quantized + auto device map on CUDA, fp32 on CPU).
        print("Loading base model:", BASE_MODEL)
        try:
            self.base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map="auto" if self.device == "cuda" else None,
                quantization_config=bnb_config,
                torch_dtype=compute_dtype if self.device == "cuda" else torch.float32,
                trust_remote_code=True,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load base model {BASE_MODEL}: {e}")

        # Load LoRA adapter: prefer the local checkpoint folder if present.
        adapter_source = None
        if os.path.exists(ADAPTER_LOCAL_DIR) and os.path.isdir(ADAPTER_LOCAL_DIR):
            adapter_source = ADAPTER_LOCAL_DIR
            print("Found local adapter checkpoint:", ADAPTER_LOCAL_DIR)
        else:
            adapter_source = HF_ADAPTER_REPO
            print("Local adapter not found — will try to load adapter from HF repo:", HF_ADAPTER_REPO)

        print(f"Loading LoRA adapter from: {adapter_source}")
        try:
            # PeftModel.from_pretrained accepts a local path or a hub repo id.
            self.model = PeftModel.from_pretrained(self.base, adapter_source, torch_dtype=compute_dtype if self.device == "cuda" else torch.float32)
        except Exception as e:
            raise RuntimeError(f"Failed to load LoRA adapter from {adapter_source}: {e}")

        # Move model to device (PeftModel wraps the base model).
        if self.device == "cuda":
            # Best-effort: with device_map="auto" + bitsandbytes the weights are
            # already placed (and .to() may raise on quantized modules), so a
            # failure here is deliberately swallowed.
            try:
                self.model.to(self.device)
            except Exception:
                # .to('cuda') is not required when device_map='auto' already placed weights
                pass
        else:
            self.model.to(self.device)

        self.model.eval()
        print("Model + adapter loaded. Device:", self.device)

    def generate_response(self, task_type: str, input_text: str, max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9) -> str:
        """Generate a response for one of the supported task types.

        Args:
            task_type: one of "email", "story", "tech", "summary", "chat";
                unknown values fall back to a generic "Provide a reply" prompt.
            input_text: user text inserted into the "### Input:" section.
            max_new_tokens / temperature / top_p: sampling controls forwarded
                to model.generate (sampling is always on, do_sample=True).

        Returns:
            The generated text after the "### Output:" marker, or an error
            string prefixed with "❌" if generation raised.
        """
        task_prompts = {
            "email": "Draft an email reply",
            "story": "Continue the story",
            "tech": "Answer the technical question",
            "summary": "Summarize the content",
            "chat": "Provide a helpful chat response"
        }
        # Prompt template must match the one used during LoRA fine-tuning —
        # the "### Output:" marker is also used below to strip the echo.
        prompt = f"### Task: {task_prompts.get(task_type,'Provide a reply')}\n\n### Input:\n{input_text}\n\n### Output:\n"

        # Tokenize (truncating long inputs at 1024 tokens), then move tensors
        # to the same device the (possibly sharded) model lives on.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        try:
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )
            text = self.tokenizer.decode(out[0], skip_special_tokens=True)
            # The decode includes the prompt; keep only what follows the marker.
            if "### Output:" in text:
                text = text.split("### Output:")[-1].strip()
            return text
        except Exception as e:
            # Surface the failure to the UI instead of crashing the app.
            return f"❌ Generation error: {e}"
132
+
133
+
134
# Create the singleton inference engine at import time so the slow model load
# happens once at startup, not on the first user request.
engine = MultitaskInference()
136
+
137
+ # Gradio UI
138
def process_request(task_type, user_input, max_tokens, temperature, top_p):
    """Validate the UI input, then delegate to the shared inference engine.

    Returns a warning string for empty/whitespace-only input; otherwise the
    generated text from engine.generate_response.
    """
    cleaned = (user_input or "").strip()
    if not cleaned:
        return "⚠️ Please enter some input text."
    return engine.generate_response(
        task_type,
        user_input,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
    )
142
+
143
+
144
# UI example rows: each entry is [task_type value, example input text] matching
# the component order passed to gr.Examples below.
examples = [
    ["chat", "Hey — my VPN won't connect. Any suggestions?"],
    ["email", "Subject: Project update\nBody: Please share the status of Task A."],
    ["story", "The lighthouse blinked twice and the fog rolled in..."],
    ["tech", "What is the difference between model.eval() and model.train() in PyTorch?"],
    ["summary", "AI systems are transforming industries through automation and data insights..."],
]
151
+
152
# Gradio UI: layout, example wiring, and event handlers.
with gr.Blocks(title="Gilbert Multitask AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"## 🚀 Gilbert Multitask AI\n\n**Base model:** {BASE_MODEL}\n\nLoRA adapter: local `{ADAPTER_LOCAL_DIR}` if present, otherwise `{HF_ADAPTER_REPO}`."
    )

    with gr.Row():
        with gr.Column(scale=1):
            task_type = gr.Dropdown(choices=["chat", "email", "story", "tech", "summary"], value="chat", label="Task")
            max_tokens = gr.Slider(50, 1024, value=200, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=2):
            input_box = gr.Textbox(lines=8, label="Input")
            output_box = gr.Textbox(lines=10, label="Generated Response", show_copy_button=True)
            btn = gr.Button("Generate")

    # BUG FIX: examples were previously wired to a throwaway hidden Textbox
    # (gr.Textbox(visible=False)), so clicking an example never populated the
    # real input field. gr.Examples is created here, after input_box exists,
    # and targets it directly; example rows are [task, text] matching inputs.
    gr.Examples(examples=examples, inputs=[task_type, input_box])

    # Both the button and pressing Enter in the input box trigger generation.
    btn.click(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)
    input_box.submit(process_request, inputs=[task_type, input_box, max_tokens, temperature, top_p], outputs=output_box)

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the Hugging Face Spaces default).
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)