rahul7star commited on
Commit
c79a78a
·
verified ·
1 Parent(s): 93af4b7

Update app_strict_lora.py

Browse files
Files changed (1) hide show
  1. app_strict_lora.py +43 -41
app_strict_lora.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py — LoRA Chat Assistant (Diffusers-specialized)
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  import torch
@@ -9,9 +8,9 @@ from datetime import datetime
9
  # ==========================================================
10
  # Configuration
11
  # ==========================================================
12
- lora_repo = "rahul7star/GPT-Diffuser-v1" # your fine-tuned LoRA model
13
- device = 0 if torch.cuda.is_available() else -1
14
- log_lines = []
15
 
16
 
17
  # ==========================================================
@@ -20,17 +19,17 @@ log_lines = []
20
  def log(msg: str):
21
  line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
22
  print(line)
23
- log_lines.append(line)
24
 
25
 
26
  # ==========================================================
27
  # Model & Tokenizer Loading
28
  # ==========================================================
29
- log(f"🚀 Loading LoRA model from {lora_repo}")
30
- log(f"Device: {'GPU' if device == 0 else 'CPU'}")
31
 
32
  try:
33
- tokenizer = AutoTokenizer.from_pretrained(lora_repo, trust_remote_code=True)
34
  if tokenizer.pad_token is None:
35
  tokenizer.pad_token = tokenizer.eos_token
36
  log(f"✅ Tokenizer loaded: vocab size {tokenizer.vocab_size}")
@@ -40,21 +39,16 @@ except Exception as e:
40
 
41
  try:
42
  model = AutoModelForCausalLM.from_pretrained(
43
- lora_repo,
44
  trust_remote_code=True,
45
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
46
  device_map="auto" if torch.cuda.is_available() else None,
47
  )
48
  model.eval()
49
- pipe = pipeline(
50
- "text-generation",
51
- model=model,
52
- tokenizer=tokenizer,
53
- device=device,
54
- )
55
- log("✅ LoRA model & pipeline ready for inference")
56
  except Exception as e:
57
- log(f"❌ LoRA model load failed: {e}")
58
  pipe = None
59
 
60
 
@@ -62,79 +56,87 @@ except Exception as e:
62
  # Chat Function
63
  # ==========================================================
64
  def chat_with_model(message, history):
65
- log_lines.clear()
66
  log(f"💭 User message: {message}")
67
 
68
  if pipe is None:
69
  return "", history, "⚠️ Model pipeline not loaded."
70
 
71
- # Context restrict to the trained domain (Diffusers GitHub repo)
 
72
  context = (
73
- "You are an expert coding assistant fine-tuned exclusively on the "
74
- "Hugging Face Diffusers GitHub repository "
75
- "(https://github.com/huggingface/diffusers.git). "
76
- "Answer questions strictly based on that repository’s Python source code, "
77
- "classes, functions, utilities, and docstrings. "
78
- "If the answer cannot be found in the diffusers repo, respond with:\n"
79
  "\"I don’t have enough information from the diffusers repository to answer that.\"\n\n"
80
  "Conversation:\n"
81
  )
82
 
83
- # Build chat context
84
  for user, bot in history:
85
  context += f"User: {user}\nAssistant: {bot}\n"
86
  context += f"User: {message}\nAssistant:"
87
 
88
  log("📄 Built conversation context")
89
 
90
- # Generate
91
  start_time = time.time()
92
  try:
93
  outputs = pipe(
94
  context,
95
- max_new_tokens=512, # 🔹 extended token limit
96
  do_sample=True,
97
- temperature=0.7,
98
  top_p=0.9,
99
- repetition_penalty=1.1,
100
  )[0]["generated_text"]
101
  elapsed = time.time() - start_time
102
  log(f"⏱️ Inference took {elapsed:.2f}s")
103
  except Exception as e:
104
  log(f"❌ Generation failed: {e}")
105
- return "", history, "\n".join(log_lines)
106
 
107
- # Extract assistant reply
108
  reply = outputs[len(context):].strip()
109
  reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
110
  reply = re.sub(r"\s{2,}", " ", reply).strip()
111
  reply = reply.split("User:")[0].split("Assistant:")[0].strip()
112
 
113
- # Fallback if empty or nonsense
114
- if not reply or len(reply) < 5:
 
 
 
 
115
  reply = "I don’t have enough information from the diffusers repository to answer that."
116
 
117
- # Format code blocks for Gradio UI
118
  if re.search(r"```|class |def |import ", reply):
119
  reply = f"```python\n{reply}\n```"
120
 
121
- log(f"🪄 Model reply: {reply[:200]}...") # preview first 200 chars
122
  history.append((message, reply))
123
- return "", history, "\n".join(log_lines)
124
 
125
 
126
  # ==========================================================
127
  # Gradio Interface
128
  # ==========================================================
129
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
130
- gr.Markdown("## 🤖 Diffusers LoRA Chat — GitHub Code-Trained Assistant")
131
 
132
  with gr.Row():
133
  with gr.Column(scale=2):
134
- chatbot = gr.Chatbot(height=500, label="Chat with Diffusers LoRA")
135
- msg = gr.Textbox(placeholder="Ask about Diffusers code...", label="Your Message")
 
 
 
136
  send = gr.Button("💬 Ask")
137
- clear = gr.Button("🧹 Clear")
138
  with gr.Column(scale=1):
139
  log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)
140
 
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  import torch
 
8
  # ==========================================================
9
  # Configuration
10
  # ==========================================================
11
# --- Configuration constants ---

# Hub repo id of the fine-tuned LoRA checkpoint to serve.
LORA_REPO = "rahul7star/GPT-Diffuser-v1"

# transformers pipeline device index: 0 = first CUDA GPU, -1 = CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

# In-memory log buffer rendered in the Gradio log panel.
LOG_LINES = []
14
 
15
 
16
  # ==========================================================
 
19
def log(msg: str):
    """Prefix *msg* with an HH:MM:SS timestamp, echo it to stdout, and
    record it in the module-level LOG_LINES buffer shown in the UI."""
    stamped = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(stamped)
    LOG_LINES.append(stamped)
23
 
24
 
25
  # ==========================================================
26
  # Model & Tokenizer Loading
27
  # ==========================================================
28
+ log(f"🚀 Loading Diffusers LoRA model from {LORA_REPO}")
29
+ log(f"Device: {'GPU' if DEVICE == 0 else 'CPU'}")
30
 
31
  try:
32
+ tokenizer = AutoTokenizer.from_pretrained(LORA_REPO, trust_remote_code=True)
33
  if tokenizer.pad_token is None:
34
  tokenizer.pad_token = tokenizer.eos_token
35
  log(f"✅ Tokenizer loaded: vocab size {tokenizer.vocab_size}")
 
39
 
40
try:
    # Load the causal-LM weights: fp16 with automatic device placement when a
    # GPU is present, plain fp32 on CPU (fp16 is poorly supported there).
    model = AutoModelForCausalLM.from_pretrained(
        LORA_REPO,
        trust_remote_code=True,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.eval()  # inference only: disables dropout/batch-norm updates
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=DEVICE,
    )
    log("✅ LoRA model pipeline ready for inference")
except Exception as e:
    log(f"❌ Model pipeline load failed: {e}")
    pipe = None  # chat handler treats this sentinel as "model unavailable"
53
 
54
 
 
56
  # Chat Function
57
  # ==========================================================
58
def chat_with_model(message, history):
    """Answer *message* using the diffusers-tuned LoRA text-generation pipeline.

    Parameters
    ----------
    message : str
        The user's new chat message.
    history : list[tuple[str, str]]
        Prior (user, assistant) turns; the accepted reply is appended in place.

    Returns
    -------
    tuple
        ("", updated_history, log_text) — the leading empty string clears the
        Gradio input textbox; log_text refreshes the log panel.
    """
    LOG_LINES.clear()  # each turn shows only its own log in the panel
    log(f"💭 User message: {message}")

    if pipe is None:
        # Model failed to load at startup; surface that instead of crashing.
        return "", history, "⚠️ Model pipeline not loaded."

    # --- STRICT CONTEXT ENFORCEMENT ---
    # The system preamble restricts answers to the diffusers GitHub repo.
    context = (
        "You are an AI assistant fine-tuned exclusively on the Hugging Face Diffusers "
        "GitHub repository (https://github.com/huggingface/diffusers.git). "
        "You must only answer questions using code, classes, functions, or documentation "
        "found within that repository. "
        "Do not reference any other frameworks, blogs, or tutorials. "
        "If the answer cannot be found in the diffusers source code, respond with:\n\n"
        "\"I don’t have enough information from the diffusers repository to answer that.\"\n\n"
        "Conversation:\n"
    )

    # Replay the prior turns so the model sees the whole dialogue.
    for user, bot in history:
        context += f"User: {user}\nAssistant: {bot}\n"
    context += f"User: {message}\nAssistant:"

    log("📄 Built conversation context")

    # --- Generation ---
    start_time = time.time()
    try:
        outputs = pipe(
            context,
            max_new_tokens=512,  # extended token limit
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
        )[0]["generated_text"]
        elapsed = time.time() - start_time
        log(f"⏱️ Inference took {elapsed:.2f}s")
    except Exception as e:
        log(f"❌ Generation failed: {e}")
        return "", history, "\n".join(LOG_LINES)

    # --- Clean response ---
    # The pipeline echoes the prompt (return_full_text default), so strip the
    # context prefix before post-processing.
    reply = outputs[len(context):].strip()
    reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
    reply = re.sub(r"\s{2,}", " ", reply).strip()
    # Cut off any hallucinated continuation of the dialogue.
    reply = reply.split("User:")[0].split("Assistant:")[0].strip()

    # --- Guardrail: reject empty or clearly off-repo answers ---
    # FIX: the previous blocklist also matched "Stable", which rejected every
    # legitimate answer mentioning diffusers' own StableDiffusionPipeline
    # classes — the main subject the assistant exists to explain.
    if (
        not reply
        or len(reply) < 5
        or re.search(r"(Fluent|OpenAI|blog|Medium|notebook|paper)", reply, re.I)
    ):
        reply = "I don’t have enough information from the diffusers repository to answer that."

    # --- Markdown-friendly formatting ---
    # FIX: only wrap code-looking replies that are NOT already fenced; the old
    # check matched ``` itself and produced broken nested fences.
    if "```" not in reply and re.search(r"class |def |import ", reply):
        reply = f"```python\n{reply}\n```"

    log(f"🪄 Model reply: {reply[:180]}...")  # preview short part
    history.append((message, reply))
    return "", history, "\n".join(LOG_LINES)
123
 
124
 
125
  # ==========================================================
126
  # Gradio Interface
127
  # ==========================================================
128
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
129
+ gr.Markdown("## 🤖 Diffusers GitHub-Trained LoRA Chat Assistant")
130
 
131
  with gr.Row():
132
  with gr.Column(scale=2):
133
+ chatbot = gr.Chatbot(height=480, label="Chat with Diffusers LoRA")
134
+ msg = gr.Textbox(
135
+ placeholder="Ask about Diffusers source code, classes, or examples...",
136
+ label="Your Message"
137
+ )
138
  send = gr.Button("💬 Ask")
139
+ clear = gr.Button("🧹 Clear Chat")
140
  with gr.Column(scale=1):
141
  log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)
142