|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
import gradio as gr |
|
|
import threading |
|
|
|
|
|
|
|
|
# --- Model setup -------------------------------------------------------
# Hugging Face Hub identifier of the fine-tuned portfolio assistant.
model_path = "SBK/sbk-llm-1"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Half precision on GPU halves memory use; CPUs handle float16 poorly,
# so fall back to full precision there.
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=_dtype,
    device_map="auto",
)

# Device the tokenized inputs are moved to; must match where the model
# weights were placed by device_map="auto".
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
# Persona / guardrail instructions prepended to every prompt sent to the
# model (see generate_response); editable at runtime via the UI textbox.
SYSTEM_PROMPT = """You are a helpful, honest, and factual assistant trained to answer only about me *Saptarshi Bhattacharya*. You were fine-tuned on factual data derived from his work, projects, skills, internships, and engineering experiences.

Your job is to help users understand what Saptarshi has done, what he's good at, and how his experience aligns with ML Ops, Data Engineering, DevOps, and related roles.

- If a user asks something outside the scope of his data, do not guess — politely say it's outside your knowledge.
- Never fabricate qualifications, names, or roles that were not in your training.
- Emphasize Saptarshi's strengths, such as completing hard technical projects, optimizing pipelines, learning on the fly, and being a completionist.
- Maintain a professional yet warm tone.
- Refer to Saptarshi in third person.

Your goal is to represent him truthfully and make his work accessible and understandable to potential collaborators or employers, without overselling or faking.
"""

# Substrings that trigger the pre-generation safety block in
# generate_response (matched case-insensitively against the whole prompt).
BLOCKED_KEYWORDS = ["violence","suicide"]
# Upper bound on newly generated tokens per reply.
MAX_TOKENS = 512
|
|
|
|
|
|
|
|
def generate_response(history, system_prompt):
    """Stream the assistant's reply for the latest user turn.

    Args:
        history: List of ``(user, assistant)`` tuples. The LAST entry is the
            pending turn whose assistant slot is an empty string.
        system_prompt: Instruction text prepended to the conversation.

    Yields:
        str: The reply accumulated so far (grows as tokens stream in), or a
        single "[Blocked...]" message if a restricted keyword is found.
    """
    # Build the prompt from COMPLETED turns only (history[:-1]); the final
    # pending turn is appended separately below. The original code iterated
    # over all of history, which duplicated the latest user message and
    # injected a stray empty "Assistant:" line into the prompt.
    prompt = system_prompt.strip() + "\n"
    for user, bot in history[:-1]:
        prompt += f"User: {user}\nAssistant: {bot}\n"
    prompt += "User: " + history[-1][0] + "\nAssistant:"

    # Cheap keyword filter; bail out before touching the model.
    # NOTE(review): this scans the whole prompt, including the system prompt
    # and prior turns — confirm that is the intended scope.
    if any(bad in prompt.lower() for bad in BLOCKED_KEYWORDS):
        yield "[Blocked for safety. Prompt contains restricted keywords.]"
        return

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # model.generate blocks until completion, so run it on a worker thread
    # and consume the streamer incrementally on this one.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_message = ""
    for token in streamer:
        partial_message += token
        yield partial_message
    # Make sure the worker has fully finished before the generator exits.
    thread.join()
|
|
|
|
|
|
|
|
with gr.Blocks(title="SBK LLM Chat") as demo:
    # The original header contained a mojibake replacement character (�);
    # restored to a sensible emoji.
    gr.Markdown("## 🤖 Chat with SBK LLM - Professional Portfolio Assistant")

    with gr.Row():
        with gr.Column(scale=1):
            system_prompt = gr.Textbox(label="System Instructions", value=SYSTEM_PROMPT, lines=8)
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Your Message", placeholder="Ask about Saptarshi's professional experience...", lines=2)
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")

    # NOTE(review): respond() reads its history from the chatbot component,
    # not from this State; it is only reset by the Clear button. Kept for
    # wiring compatibility, but effectively vestigial.
    history = gr.State([])

    def respond(user_message, chat_history, system_prompt):
        """Append the user's turn and stream the assistant reply.

        Yields the updated chat history after each streamed chunk so the
        Chatbot component re-renders incrementally.
        """
        # Guard: ignore empty/whitespace-only submissions instead of
        # sending a blank turn to the model (the original created an
        # empty turn and invoked generation anyway).
        if not user_message or not user_message.strip():
            yield chat_history
            return

        chat_history = chat_history + [(user_message, "")]
        for partial in generate_response(chat_history, system_prompt):
            # Overwrite the pending assistant slot with the text so far.
            chat_history[-1] = (user_message, partial)
            yield chat_history

    msg.submit(
        respond,
        [msg, chatbot, system_prompt],
        [chatbot],
        queue=True
    )
    submit_btn.click(
        respond,
        [msg, chatbot, system_prompt],
        [chatbot],
        queue=True
    )
    clear_btn.click(
        lambda: ([], []),
        outputs=[chatbot, history],
        queue=False
    )

# queue() is required for generator (streaming) handlers; share=True creates
# a public Gradio link, and 0.0.0.0 binds all interfaces (container-friendly).
demo.queue(max_size=20).launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True
)