rahul7star committed on
Commit
93af4b7
·
verified ·
1 Parent(s): 68f2574

Update app_strict_lora.py

Browse files
Files changed (1) hide show
  1. app_strict_lora.py +67 -66
app_strict_lora.py CHANGED
@@ -1,55 +1,34 @@
1
- # app.py (LoRA-only loading)
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, pipeline
4
  import torch
5
- import os
6
  import re
7
- import json
8
  import time
9
  from datetime import datetime
10
- from huggingface_hub import model_info
11
-
12
- import os, shutil, glob
13
-
14
def cleanup_space(paths=None):
    """Best-effort removal of caches and training artifacts to free disk space.

    Args:
        paths: Optional iterable of paths and/or glob patterns to delete.
            Defaults to the Space's known cache and checkpoint locations.

    Glob patterns (e.g. "./checkpoint*") are expanded before deletion —
    previously the pattern was passed verbatim to os.path.isdir and never
    matched anything. Every failure is printed and skipped, never raised,
    so cleanup stays best-effort.
    """
    print("🧹 Cleaning up cache and checkpoints...")
    if paths is None:
        paths = [
            "/root/.cache/huggingface/hub",
            "/root/.cache/torch",
            "./qwen-gita-lora",
            "./runs",
            "./checkpoint*",
            "./repo_tmp",
            "./tmp",
            "/tmp",
        ]
    for pattern in paths:
        # Expand glob patterns; a non-matching literal falls back to itself
        # so plain paths behave exactly as before.
        for p in glob.glob(pattern) or [pattern]:
            try:
                if os.path.isdir(p):
                    shutil.rmtree(p)
                elif os.path.exists(p):
                    os.remove(p)
            except Exception as e:
                print("⚠️ Skip cleanup for", p, e)
34
-
35
- cleanup_space()
36
-
37
-
38
- # ===== Settings =====
39
- device = 0 if torch.cuda.is_available() else -1
40
- lora_repo = "rahul7star/GPT-Diffuser-v1" # ONLY LoRA fine-tuned repo
41
 
 
 
 
 
 
42
  log_lines = []
43
 
44
- def log(msg):
 
 
 
 
45
  line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
46
  print(line)
47
  log_lines.append(line)
48
 
49
- log(f"🚀 Loading LoRA-only model from {lora_repo}")
50
- log(f"Device: {'GPU' if device==0 else 'CPU'}")
51
 
52
- # ====== Tokenizer ======
 
 
 
 
 
53
  try:
54
  tokenizer = AutoTokenizer.from_pretrained(lora_repo, trust_remote_code=True)
55
  if tokenizer.pad_token is None:
@@ -59,9 +38,6 @@ except Exception as e:
59
  log(f"❌ Tokenizer load failed: {e}")
60
  tokenizer = None
61
 
62
- # ====== LoRA-only model ======
63
- model = None
64
- pipe = None
65
  try:
66
  model = AutoModelForCausalLM.from_pretrained(
67
  lora_repo,
@@ -70,18 +46,21 @@ try:
70
  device_map="auto" if torch.cuda.is_available() else None,
71
  )
72
  model.eval()
73
- log("✅ LoRA-only model loaded successfully")
74
  pipe = pipeline(
75
  "text-generation",
76
  model=model,
77
  tokenizer=tokenizer,
78
  device=device,
79
  )
80
- log("✅ Pipeline ready for inference")
81
  except Exception as e:
82
  log(f"❌ LoRA model load failed: {e}")
 
 
83
 
84
- # ====== Chat Function ======
 
 
85
  def chat_with_model(message, history):
86
  log_lines.clear()
87
  log(f"💭 User message: {message}")
@@ -89,61 +68,83 @@ def chat_with_model(message, history):
89
  if pipe is None:
90
  return "", history, "⚠️ Model pipeline not loaded."
91
 
 
92
  context = (
93
- "You are a coding assistant **fine-tuned exclusively on the Hugging Face Diffusers GitHub repository** "
94
- "(https://github.com/huggingface/diffusers.git). "
95
- "Answer questions strictly based on that repository’s Python source code, docstrings, and implementation details. "
96
- "If the answer cannot be found or inferred directly from the diffusers codebase, respond with:\n"
97
- "\"I don’t have enough information from the diffusers repository to answer that.\"\n\n"
98
- "Conversation:\n"
99
- )
 
 
100
 
 
101
  for user, bot in history:
102
  context += f"User: {user}\nAssistant: {bot}\n"
103
  context += f"User: {message}\nAssistant:"
104
 
105
  log("📄 Built conversation context")
106
- log(context)
107
 
 
108
  start_time = time.time()
109
  try:
110
- output = pipe(
111
  context,
112
- max_new_tokens=200,
113
  do_sample=True,
114
  temperature=0.7,
115
  top_p=0.9,
116
  repetition_penalty=1.1,
117
  )[0]["generated_text"]
118
- log(f"⏱️ Inference took {time.time() - start_time:.2f}s")
 
119
  except Exception as e:
120
  log(f"❌ Generation failed: {e}")
121
  return "", history, "\n".join(log_lines)
122
 
123
- # Clean reply
124
- reply = output[len(context):].strip()
125
- reply = re.sub(r"(ContentLoaded|<\/?[^>]+>|[\r\n]{2,})", " ", reply)
126
  reply = re.sub(r"\s{2,}", " ", reply).strip()
127
  reply = reply.split("User:")[0].split("Assistant:")[0].strip()
128
 
129
- log(f"🪄 Model reply: {reply}")
 
 
 
 
 
 
 
 
130
  history.append((message, reply))
131
  return "", history, "\n".join(log_lines)
132
 
133
- # ===== Gradio =====
 
 
 
134
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
135
- gr.Markdown("## 💬 Qwen LoRA-onlyBhagavad Gita Assistant")
136
 
137
  with gr.Row():
138
  with gr.Column(scale=2):
139
- chatbot = gr.Chatbot(height=500)
140
- msg = gr.Textbox(placeholder="Ask about the Gita...", label="Your Message")
141
- clear = gr.Button("Clear")
 
142
  with gr.Column(scale=1):
143
  log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)
144
 
 
145
  msg.submit(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
146
  clear.click(lambda: (None, None, ""), None, [chatbot, log_box], queue=False)
147
 
 
 
 
 
148
  if __name__ == "__main__":
149
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # app.py — LoRA Chat Assistant (Diffusers-specialized)
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  import torch
 
5
  import re
 
6
  import time
7
  from datetime import datetime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# ==========================================================
# Configuration
# ==========================================================
device = 0 if torch.cuda.is_available() else -1  # 0 = first GPU, -1 = CPU
lora_repo = "rahul7star/GPT-Diffuser-v1"  # fine-tuned LoRA model repo
log_lines = []  # shared in-memory buffer that log() appends to
15
 
16
+
17
# ==========================================================
# Logging helper
# ==========================================================
def log(msg: str):
    """Print *msg* with an HH:MM:SS timestamp and record it in log_lines."""
    stamp = datetime.now().strftime("%H:%M:%S")
    entry = f"[{stamp}] {msg}"
    print(entry)
    log_lines.append(entry)
24
 
 
 
25
 
26
+ # ==========================================================
27
+ # Model & Tokenizer Loading
28
+ # ==========================================================
29
+ log(f"🚀 Loading LoRA model from {lora_repo}")
30
+ log(f"Device: {'GPU' if device == 0 else 'CPU'}")
31
+
32
  try:
33
  tokenizer = AutoTokenizer.from_pretrained(lora_repo, trust_remote_code=True)
34
  if tokenizer.pad_token is None:
 
38
  log(f"❌ Tokenizer load failed: {e}")
39
  tokenizer = None
40
 
 
 
 
41
  try:
42
  model = AutoModelForCausalLM.from_pretrained(
43
  lora_repo,
 
46
  device_map="auto" if torch.cuda.is_available() else None,
47
  )
48
  model.eval()
 
49
  pipe = pipeline(
50
  "text-generation",
51
  model=model,
52
  tokenizer=tokenizer,
53
  device=device,
54
  )
55
+ log("✅ LoRA model & pipeline ready for inference")
56
  except Exception as e:
57
  log(f"❌ LoRA model load failed: {e}")
58
+ pipe = None
59
+
60
 
61
# ==========================================================
# Chat Function
# ==========================================================
def chat_with_model(message, history):
    """Generate a reply for *message* given the chat *history*.

    Returns the 3-tuple Gradio expects: cleared textbox value, updated
    history (mutated in place, with the new turn appended), and the
    newline-joined log text for this request.
    """
    log_lines.clear()
    log(f"💭 User message: {message}")

    if pipe is None:
        return "", history, "⚠️ Model pipeline not loaded."

    # Context — restrict the model to its trained domain (Diffusers repo)
    context = (
        "You are an expert coding assistant fine-tuned exclusively on the "
        "Hugging Face Diffusers GitHub repository "
        "(https://github.com/huggingface/diffusers.git). "
        "Answer questions strictly based on that repository’s Python source code, "
        "classes, functions, utilities, and docstrings. "
        "If the answer cannot be found in the diffusers repo, respond with:\n"
        "\"I don’t have enough information from the diffusers repository to answer that.\"\n\n"
        "Conversation:\n"
    )

    # Replay prior turns, then append the new user message as the live turn.
    for user, bot in history:
        context += f"User: {user}\nAssistant: {bot}\n"
    context += f"User: {message}\nAssistant:"

    log("📄 Built conversation context")

    # Generate
    start_time = time.time()
    try:
        outputs = pipe(
            context,
            max_new_tokens=512,  # extended token limit
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )[0]["generated_text"]
        log(f"⏱️ Inference took {time.time() - start_time:.2f}s")
    except Exception as e:
        log(f"❌ Generation failed: {e}")
        return "", history, "\n".join(log_lines)

    # Keep only the newly generated text after the prompt.
    reply = outputs[len(context):].strip()
    # Strip HTML-ish tags and squash runs of 3+ newlines down to one.
    reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
    # Collapse horizontal whitespace only — the previous r"\s{2,}" also
    # flattened the newlines just inserted above, mangling code replies.
    reply = re.sub(r"[ \t]{2,}", " ", reply).strip()
    # Truncate at any hallucinated next turn.
    reply = reply.split("User:")[0].split("Assistant:")[0].strip()

    # Fallback for empty or near-empty generations.
    if not reply or len(reply) < 5:
        reply = "I don’t have enough information from the diffusers repository to answer that."

    # Wrap code-looking replies in a fence for the Gradio chat UI,
    # but never double-wrap a reply that already contains a fence.
    if "```" not in reply and re.search(r"class |def |import ", reply):
        reply = f"```python\n{reply}\n```"

    log(f"🪄 Model reply: {reply[:200]}...")  # preview first 200 chars
    history.append((message, reply))
    return "", history, "\n".join(log_lines)
124
 
125
+
126
# ==========================================================
# Gradio Interface
# ==========================================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("## 🤖 Diffusers LoRA Chat GitHub Code-Trained Assistant")

    with gr.Row():
        with gr.Column(scale=2):
            # Main pane: conversation view, input box, and action buttons.
            chatbot = gr.Chatbot(height=500, label="Chat with Diffusers LoRA")
            msg = gr.Textbox(placeholder="Ask about Diffusers code...", label="Your Message")
            send = gr.Button("💬 Ask")
            clear = gr.Button("🧹 Clear")
        with gr.Column(scale=1):
            # Side pane: read-only log produced by chat_with_model for the
            # most recent request (log_lines is cleared per call).
            log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)

    # Both the Ask button and pressing Enter in the textbox run the same
    # handler with identical inputs/outputs.
    send.click(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    msg.submit(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    # Clear resets the chat and log immediately, bypassing the queue.
    clear.click(lambda: (None, None, ""), None, [chatbot, log_box], queue=False)
144
 
145
+
146
# ==========================================================
# Run App
# ==========================================================
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the standard HF Spaces port).
    demo.launch(server_name="0.0.0.0", server_port=7860)