File size: 5,672 Bytes
eca54e8
93af4b7
eca54e8
 
 
 
 
93af4b7
 
 
c79a78a
 
 
eca54e8
93af4b7
 
 
 
 
eca54e8
 
c79a78a
eca54e8
 
93af4b7
 
 
c79a78a
 
93af4b7
eca54e8
c79a78a
eca54e8
 
 
 
 
 
 
 
 
c79a78a
eca54e8
 
 
 
 
c79a78a
 
eca54e8
c79a78a
93af4b7
 
eca54e8
93af4b7
 
 
eca54e8
c79a78a
eca54e8
 
 
 
 
c79a78a
 
732612d
 
68f2574
732612d
 
 
 
 
 
 
 
 
68f2574
c79a78a
eca54e8
 
 
 
 
 
c79a78a
eca54e8
 
93af4b7
eca54e8
c79a78a
eca54e8
c79a78a
eca54e8
c79a78a
eca54e8
93af4b7
 
eca54e8
 
c79a78a
eca54e8
c79a78a
a8f7dc4
 
93af4b7
 
eca54e8
 
a8f7dc4
c79a78a
 
 
 
 
 
93af4b7
 
c79a78a
93af4b7
 
 
c79a78a
eca54e8
c79a78a
eca54e8
93af4b7
 
 
 
eca54e8
246d52e
eca54e8
 
 
246d52e
c79a78a
 
 
 
93af4b7
c79a78a
eca54e8
 
 
93af4b7
eca54e8
 
 
93af4b7
 
 
 
eca54e8
93af4b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import re
import time
from datetime import datetime

# ==========================================================
# Configuration
# ==========================================================
# Hugging Face Hub repo id that both the tokenizer and the causal-LM weights
# are loaded from further down in this module.
LORA_REPO = "rahul7star/GPT-Diffuser-v1"

# Device index handed to the transformers pipeline: 0 selects the first CUDA
# device, -1 selects the CPU.
DEVICE = -1 if not torch.cuda.is_available() else 0

# In-memory buffer of formatted log lines; filled by log() and shown in the UI.
LOG_LINES = list()


# ==========================================================
# Logging helper
# ==========================================================
def log(msg: str):
    """Echo *msg* to stdout with an HH:MM:SS prefix and record it in LOG_LINES.

    Args:
        msg: Text to log; it is placed after the bracketed timestamp.
    """
    stamp = datetime.now().strftime('%H:%M:%S')
    entry = f"[{stamp}] {msg}"
    # Record first, then echo; the two sinks are independent of each other.
    LOG_LINES.append(entry)
    print(entry)


# ==========================================================
# Model & Tokenizer Loading
# ==========================================================
log(f"🚀 Loading Diffusers LoRA model from {LORA_REPO}")
log(f"Device: {'GPU' if DEVICE == 0 else 'CPU'}")

# NOTE(security): trust_remote_code=True executes Python code shipped in the
# Hub repo. Only keep it enabled for repositories you trust.
try:
    tokenizer = AutoTokenizer.from_pretrained(LORA_REPO, trust_remote_code=True)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    log(f"✅ Tokenizer loaded: vocab size {tokenizer.vocab_size}")
except Exception as e:
    # Best effort: keep the app importable and surface the failure in the log.
    log(f"❌ Tokenizer load failed: {e}")
    tokenizer = None

try:
    # pipeline() cannot infer a tokenizer from an in-memory model object, so
    # fail fast instead of loading the weights only to error out afterwards.
    if tokenizer is None:
        raise RuntimeError("tokenizer is unavailable")

    _use_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        LORA_REPO,
        trust_remote_code=True,
        torch_dtype=torch.float16 if _use_cuda else torch.float32,
        device_map="auto" if _use_cuda else None,
    )
    model.eval()

    # With device_map="auto" the model is already placed by accelerate;
    # transformers rejects (or warns about) an explicit `device` in that case,
    # so only pass `device` on the CPU path where no device map is used.
    _pipe_kwargs = {} if _use_cuda else {"device": DEVICE}
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **_pipe_kwargs)
    log("✅ LoRA model pipeline ready for inference")
except Exception as e:
    # chat_with_model() checks for `pipe is None` and reports it to the user.
    log(f"❌ Model pipeline load failed: {e}")
    pipe = None


# ==========================================================
# Chat Function
# ==========================================================
def chat_with_model(message, history):
    """Generate one assistant reply and append the turn to the chat history.

    A single text prompt is assembled from a fixed instruction preamble, the
    previous ``(user, assistant)`` turns and the new message. The module-level
    ``pipe`` samples a continuation, which is then cleaned, passed through a
    keyword guardrail and optionally wrapped in a fenced code block.

    Args:
        message: Text the user just submitted.
        history: List of ``(user, assistant)`` pairs, or ``None``. The Clear
            button in this file resets the Chatbot value to ``None``, so the
            next call may arrive without a list.

    Returns:
        A 3-tuple ``(textbox_value, history, log_text)``: an empty string for
        the input box, the history (extended by the new turn on success), and
        the log lines collected during this call joined by newlines. When the
        pipeline is not loaded, ``log_text`` is a fixed warning string.
    """
    LOG_LINES.clear()
    log(f"💭 User message: {message}")

    # Treat a missing history as an empty conversation; iterating or
    # appending to None below would raise TypeError.
    if history is None:
        history = []

    if pipe is None:
        return "", history, "⚠️ Model pipeline not loaded."

    # --- STRICT CONTEXT ENFORCEMENT ---
    # Instruction preamble that asks the model to answer only from what the
    # fine-tune learned and to emit a fixed refusal otherwise.
    context = (
        f"You are an AI assistant fine-tuned exclusively with the LoRA model from "
        f"'{LORA_REPO}'. "
        "Answer strictly using knowledge, code, classes, functions, or documentation "
        "learned by this LoRA. "
        "Do not reference any other models, frameworks, tutorials, blogs, or external sources. "
        "If the answer cannot be derived from this LoRA, respond with:\n\n"
        "\"I don’t have enough information from this LoRA to answer that.\"\n\n"
        "Conversation:\n"
    )

    # Build conversation history
    context += "".join(f"User: {user}\nAssistant: {bot}\n" for user, bot in history)
    context += f"User: {message}\nAssistant:"

    log("📄 Built conversation context")

    # --- Generation ---
    start_time = time.time()
    try:
        outputs = pipe(
            context,
            max_new_tokens=512,   # extended token limit
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
        )[0]["generated_text"]
    except Exception as e:
        # Report the failure in the log panel and leave the history untouched.
        log(f"❌ Generation failed: {e}")
        return "", history, "\n".join(LOG_LINES)
    elapsed = time.time() - start_time
    log(f"⏱️ Inference took {elapsed:.2f}s")

    # --- Clean response ---
    print("output====")
    print(outputs)
    # The generated text starts with the prompt; keep only the continuation.
    reply = outputs[len(context):].strip()
    # Drop tag-like spans and runs of 3+ line breaks, then collapse whitespace.
    reply = re.sub(r"(<[^>]+>|[\r\n]{3,})", "\n", reply)
    reply = re.sub(r"\s{2,}", " ", reply).strip()
    # Cut off any further turns the model hallucinated after its own answer.
    reply = reply.split("User:")[0].split("Assistant:")[0].strip()
    print(reply)

    # --- Guardrail: only use diffusers context ---
    # Replace empty/very short replies and replies that contain any of the
    # listed keywords (case-insensitive substring match) with a fixed refusal.
    if (
        not reply
        or len(reply) < 5
        or re.search(r"(Fluent|OpenAI|Stable|blog|Medium|notebook|paper)", reply, re.I)
    ):
        reply = "I don’t have enough information from the diffusers repository to answer that."

    # --- Markdown-friendly formatting ---
    # Wrap code-looking replies in a fenced block, but never wrap a reply that
    # already contains a fence: nesting fences produces broken markdown.
    if "```" not in reply and re.search(r"class |def |import ", reply):
        reply = f"```python\n{reply}\n```"

    log(f"🪄 Model reply: {reply[:180]}...")  # preview short part
    history.append((message, reply))
    return "", history, "\n".join(LOG_LINES)


# ==========================================================
# Gradio Interface
# ==========================================================
# Two-column layout: chat controls on the left, the per-request log on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=480, label="Chat with LoRA")
            msg = gr.Textbox(
                placeholder="Ask about Diffusers source code, classes, or examples...",
                label="Your Message"
            )
            send = gr.Button("💬 Ask")
            clear = gr.Button("🧹 Clear Chat")
        with gr.Column(scale=1):
            log_box = gr.Textbox(label="Detailed Model Log", lines=25, interactive=False)

    # The button click and pressing Enter in the textbox trigger the same handler.
    chat_inputs = [msg, chatbot]
    chat_outputs = [msg, chatbot, log_box]
    send.click(chat_with_model, chat_inputs, chat_outputs)
    msg.submit(chat_with_model, chat_inputs, chat_outputs)

    # Return exactly one value per declared output (chatbot, log_box); the
    # previous 3-tuple did not match the two output components.
    clear.click(lambda: (None, ""), None, [chatbot, log_box], queue=False)


# ==========================================================
# Run App
# ==========================================================
if __name__ == "__main__":
    # Serve the UI on every network interface at port 7860 when run as a script.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)