HuggingFace Space (status: Sleeping) — Gradio chat demo for the GPT-Diffuser-v1 LoRA model.
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| import torch | |
| import re | |
| import time | |
| from datetime import datetime | |
# ==========================================================
# Configuration
# ==========================================================
LORA_REPO = "rahul7star/GPT-Diffuser-v1"  # fine-tuned LoRA model (Diffusers-based)
# transformers pipeline device index: 0 = first CUDA GPU, -1 = CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
# Per-request log buffer; rendered in the UI's "Detailed Model Log" panel.
LOG_LINES = []

# ==========================================================
# Logging helper
# ==========================================================
def log(msg: str):
    """Echo *msg* to stdout with an HH:MM:SS timestamp and record it in LOG_LINES."""
    stamped = "[{}] {}".format(datetime.now().strftime("%H:%M:%S"), msg)
    print(stamped)
    LOG_LINES.append(stamped)
# ==========================================================
# Model & Tokenizer Loading
# ==========================================================
log(f"🚀 Loading Diffusers LoRA model from {LORA_REPO}")
log(f"Device: {'GPU' if DEVICE == 0 else 'CPU'}")

try:
    tokenizer = AutoTokenizer.from_pretrained(LORA_REPO, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
    log(f"✅ Tokenizer loaded: vocab size {tokenizer.vocab_size}")
except Exception as e:
    log(f"❌ Tokenizer load failed: {e}")
    tokenizer = None

try:
    model = AutoModelForCausalLM.from_pretrained(
        LORA_REPO,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()
    # BUG FIX: when the model is dispatched by accelerate (device_map="auto"),
    # pipeline() must NOT also receive an explicit `device` — transformers
    # raises "the model has been loaded with accelerate and therefore cannot
    # be moved to a specific device". Pass `device` only on the CPU path.
    if torch.cuda.is_available():
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    else:
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=DEVICE)
    log("✅ LoRA model pipeline ready for inference")
except Exception as e:
    log(f"❌ Model pipeline load failed: {e}")
    pipe = None
# ==========================================================
# Chat Function
# ==========================================================
def _build_prompt(message, history):
    """Assemble the strict-context system preamble plus the chat transcript."""
    prompt = (
        f"You are an AI assistant fine-tuned exclusively with the LoRA model from "
        f"'{LORA_REPO}'. "
        "Answer strictly using knowledge, code, classes, functions, or documentation "
        "learned by this LoRA. "
        "Do not reference any other models, frameworks, tutorials, blogs, or external sources. "
        "If the answer cannot be derived from this LoRA, respond with:\n\n"
        "\"I don’t have enough information from this LoRA to answer that.\"\n\n"
        "Conversation:\n"
    )
    for user_turn, bot_turn in history:
        prompt += f"User: {user_turn}\nAssistant: {bot_turn}\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt


def _clean_reply(generated, prompt):
    """Strip the echoed prompt from *generated* and tidy whitespace/markup."""
    # text-generation pipelines return the prompt + completion; slice it off.
    reply = generated[len(prompt):].strip()
    # Drop HTML-ish tags and collapse runs of 3+ line breaks to one newline.
    reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
    # BUG FIX: collapse only horizontal whitespace. The old r"\s{2,}" also
    # consumed newlines, flattening multi-line code answers to one line just
    # before they get wrapped in a ```python fence by the caller.
    reply = re.sub(r"[ \t]{2,}", " ", reply).strip()
    # Cut off any hallucinated continuation of the dialogue.
    reply = reply.split("User:")[0].split("Assistant:")[0].strip()
    return reply


def chat_with_model(message, history):
    """Handle one chat turn for the Gradio UI.

    Args:
        message: the user's new message.
        history: list of (user, assistant) tuples from gr.Chatbot.

    Returns:
        ("", updated_history, log_text) — clears the input textbox, appends
        the new exchange to the chat, and refreshes the log panel.
    """
    LOG_LINES.clear()
    log(f"💭 User message: {message}")
    if pipe is None:
        return "", history, "⚠️ Model pipeline not loaded."

    # --- Strict context enforcement: answers must come from the LoRA only ---
    context = _build_prompt(message, history)
    log("📄 Built conversation context")

    # --- Generation ---
    start_time = time.time()
    try:
        outputs = pipe(
            context,
            max_new_tokens=512,  # extended token limit
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
        )[0]["generated_text"]
        log(f"⏱️ Inference took {time.time() - start_time:.2f}s")
    except Exception as e:
        log(f"❌ Generation failed: {e}")
        return "", history, "\n".join(LOG_LINES)

    # --- Clean response ---
    reply = _clean_reply(outputs, context)

    # --- Guardrail: reject empty/too-short replies or off-context sources ---
    if (
        not reply
        or len(reply) < 5
        or re.search(r"(Fluent|OpenAI|Stable|blog|Medium|notebook|paper)", reply, re.I)
    ):
        reply = "I don’t have enough information from the diffusers repository to answer that."

    # --- Markdown-friendly formatting: fence anything that looks like code ---
    if re.search(r"```|class |def |import ", reply):
        reply = f"```python\n{reply}\n```"

    log(f"🪄 Model reply: {reply[:180]}...")  # preview short part
    history.append((message, reply))
    return "", history, "\n".join(LOG_LINES)
# ==========================================================
# Gradio Interface
# ==========================================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=480, label="Chat with LoRA")
            msg = gr.Textbox(
                placeholder="Ask about Diffusers source code, classes, or examples...",
                label="Your Message"
            )
            send = gr.Button("💬 Ask")
            clear = gr.Button("🧹 Clear Chat")
        with gr.Column(scale=1):
            log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)

    # Both the button and pressing Enter trigger the same handler.
    send.click(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    msg.submit(chat_with_model, [msg, chatbot], [msg, chatbot, log_box])
    # BUG FIX: the old lambda returned a 3-tuple (None, None, "") for only
    # TWO outputs, so Gradio errored when Clear was clicked. Return exactly
    # one value per output component.
    clear.click(lambda: (None, ""), None, [chatbot, log_box], queue=False)
# ==========================================================
# Run App
# ==========================================================
if __name__ == "__main__":
    # Bind on all interfaces, port 7860 (the HuggingFace Spaces default).
    demo.launch(server_name="0.0.0.0", server_port=7860)