rahuldhole commited on
Commit
9a82c6a
·
verified ·
1 Parent(s): 61a668b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -4,42 +4,37 @@ from peft import PeftModel
4
  import os
5
 
6
  model_id = "Qwen/Qwen2.5-0.5B-Instruct"
7
- # Path to adapter - local if exists, else load from Hub
 
8
  local_adapter = "outputs/qwen-fine-tuned"
9
- # Environment variables from HF Space secret settings
10
- hf_username = os.getenv("HF_USERNAME")
11
- hf_model_name = os.getenv("HF_MODEL_NAME")
12
- hub_adapter = f"{hf_username}/{hf_model_name}" if hf_username and hf_model_name else None
13
 
14
- # Prioritize local folder but fallback to hub repo
15
  adapter_path = local_adapter if os.path.exists(local_adapter) else hub_adapter
16
 
17
- # Handle device detection for varied environments
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  if not torch.cuda.is_available() and torch.backends.mps.is_available():
20
  device = "mps"
21
 
22
- print(f"Loading model on {device}...")
23
- print(f"Using adapter path: {adapter_path}")
24
 
25
  tokenizer = AutoTokenizer.from_pretrained(model_id)
26
- # Load in float16 for memory efficiency
27
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
28
 
29
- if adapter_path:
30
- try:
31
- model = PeftModel.from_pretrained(model, adapter_path)
32
- print("Adapter loaded successfully!")
33
- except Exception as e:
34
- print(f"Warning: Could not load adapter from {adapter_path}: {e}")
35
- else:
36
- print("Warning: No adapter found. Using base model.")
37
 
38
  def chat(message, history):
39
  msgs = [{"role": "user", "content": message}]
40
  text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
41
- model_inputs = tokenizer([text], return_tensors="pt").to(device)
42
- ids = model.generate(**model_inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
43
- return tokenizer.decode(ids[0][len(model_inputs.input_ids[0]):], skip_special_tokens=True)
44
 
45
- gr.ChatInterface(chat).launch()
 
4
  import os
5
 
6
  model_id = "Qwen/Qwen2.5-0.5B-Instruct"
7
+
8
+ # Adapter source: local folder first, then Hub repo
9
  local_adapter = "outputs/qwen-fine-tuned"
10
+ hub_adapter = os.getenv("HF_MODEL_NAME", "rahuldhole/tiny-llm-qwen-adapter")
11
+ # Prefix with username if it's just a name
12
+ if "/" not in hub_adapter:
13
+ hub_adapter = f"{os.getenv('HF_USERNAME', 'rahuldhole')}/{hub_adapter}"
14
 
 
15
  adapter_path = local_adapter if os.path.exists(local_adapter) else hub_adapter
16
 
17
+ # Device
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  if not torch.cuda.is_available() and torch.backends.mps.is_available():
20
  device = "mps"
21
 
22
+ print(f"Device: {device} | Adapter: {adapter_path}")
 
23
 
24
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
25
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
26
 
27
+ try:
28
+ model = PeftModel.from_pretrained(model, adapter_path)
29
+ print("✅ Adapter loaded!")
30
+ except Exception as e:
31
+ print(f"⚠️ Adapter not loaded ({e}), using base model.")
 
 
 
32
 
33
  def chat(message, history):
34
  msgs = [{"role": "user", "content": message}]
35
  text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
36
+ inputs = tokenizer([text], return_tensors="pt").to(device)
37
+ ids = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
38
+ return tokenizer.decode(ids[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
39
 
40
+ gr.ChatInterface(chat, title="Tiny LLM Chat", description="Chat with a fine-tuned Qwen 0.5B model").launch()