rahuldhole committed on
Commit
6a35fde
·
verified ·
1 Parent(s): df5268f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+ import os
5
+
6
# Base checkpoint the fine-tuned adapter is applied on top of.
model_id = "Qwen/Qwen2.5-0.5B-Instruct"

# Adapter location: a local training output folder is preferred; otherwise the
# adapter is pulled from a Hub repo named by the Space's secrets.
local_adapter = "outputs/qwen-fine-tuned"

# Environment variables from HF Space secret settings.
hf_username = os.getenv("HF_USERNAME")
hf_model_name = os.getenv("HF_MODEL_NAME")
hub_adapter = f"{hf_username}/{hf_model_name}" if hf_username and hf_model_name else None

# Only trust the local folder when it actually holds a PEFT adapter config.
# A bare existence check would let an empty or half-written directory shadow
# the Hub repo, leaving the app on the base model.
_local_adapter_ready = os.path.isfile(os.path.join(local_adapter, "adapter_config.json"))

# Prioritize the local folder but fall back to the Hub repo (or None).
adapter_path = local_adapter if _local_adapter_ready else hub_adapter
16
+
17
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Half precision saves memory on an accelerator, but float16 inference on CPU
# is slow and can lack kernels for some ops, so use float32 there.
model_dtype = torch.float32 if device == "cpu" else torch.float16

print(f"Loading model on {device}...")
print(f"Using adapter path: {adapter_path}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Place the whole model on the detected device (rather than "auto") so the
# weights live on the same device the chat handler moves its inputs to.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=model_dtype, device_map=device)

if adapter_path:
    try:
        model = PeftModel.from_pretrained(model, adapter_path)
        print("Adapter loaded successfully!")
    except Exception as e:
        # Deliberate best-effort: a missing or incompatible adapter must not
        # take the app down; keep serving the base model and report why.
        print(f"Warning: Could not load adapter from {adapter_path}: {e}")
else:
    print("Warning: No adapter found. Using base model.")
37
+
38
def _history_to_messages(history):
    """Convert Gradio chat history into a list of chat-template messages.

    Accepts both history layouts: a list of ``{"role", "content"}`` dicts, or
    a list of ``(user, assistant)`` pairs. Entries whose content is not plain
    text (e.g. file attachments, or a pending ``None`` reply) are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("system", "user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, assistant_text = turn
            if isinstance(user_text, str):
                messages.append({"role": "user", "content": user_text})
            if isinstance(assistant_text, str):
                messages.append({"role": "assistant", "content": assistant_text})
    return messages


def chat(message, history):
    """Generate the assistant's reply to ``message`` given the earlier turns.

    Args:
        message: The latest user message as plain text.
        history: Earlier turns of the conversation as passed by
            ``gr.ChatInterface``; may be empty or ``None``.

    Returns:
        The newly generated reply text, with special tokens stripped.
    """
    # Include prior turns so the model answers in context instead of treating
    # every message as the start of a new conversation.
    msgs = _history_to_messages(history)
    msgs.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # Send inputs to wherever the model's weights actually live; this avoids a
    # device mismatch if placement differs from the detected device string.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    ids = model.generate(**model_inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)

    # generate() returns prompt + completion; decode only the new tokens.
    prompt_len = model_inputs.input_ids.shape[1]
    return tokenizer.decode(ids[0][prompt_len:], skip_special_tokens=True)
44
+
45
# Wrap the chat handler in Gradio's chat UI and start serving it.
demo = gr.ChatInterface(chat)
demo.launch()