import os
import gradio as gr
import torch
# Hugging Face Hub repo id of the LoRA chat adapter to serve.
MODEL_ID = "SatyamSinghal/taskmind-1.1b-chat-lora"

# Hub auth token from the environment (e.g. a Space secret);
# None falls back to anonymous access.
HF_TOKEN = os.getenv("HF_TOKEN")

# Module-level cache for the text-generation pipeline; populated lazily
# by load_model() on the first chat request.
pipe = None
def load_model():
    """Lazily build the text-generation pipeline for MODEL_ID.

    Idempotent: the finished pipeline is cached in the module-level
    ``pipe`` global, so any call after a successful load is a no-op.
    """
    global pipe

    # Fast path: already loaded.
    if pipe is not None:
        return

    # Heavy libraries are imported here, not at module top, so the app
    # starts instantly and only pays the cost on the first request.
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer, pipeline

    # fp16 on GPU; fp32 on CPU, where half precision is poorly supported.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    print("Loading model...")
    lora_model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    pipe = pipeline("text-generation", model=lora_model, tokenizer=tok)
    print("Model loaded successfully.")
def respond(message, history):
    """Chat callback for gr.ChatInterface.

    Parameters
    ----------
    message : str
        The new user message.
    history : list[tuple[str, str]]
        Prior (user, assistant) turns in ChatInterface's tuple format.

    Returns
    -------
    str
        The assistant's reply, or a human-readable error string if the
        model failed to load or generation failed.
    """
    try:
        load_model()
    except Exception as e:
        return f"❌ Model failed to load: {str(e)}"

    # Rebuild the conversation in the chat-template "messages" format.
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # FIX: generation errors were previously uncaught and would crash the
    # Gradio callback; surface them as chat text, matching the
    # load-failure handling above.
    try:
        result = pipe(
            messages,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    except Exception as e:
        return f"❌ Generation failed: {str(e)}"

    generated = result[0]["generated_text"]
    # With messages-format input, the pipeline returns the whole
    # conversation; the last entry is the newly generated assistant turn.
    if isinstance(generated, list):
        return generated[-1]["content"]
    return str(generated)
# Gradio chat UI wired to the respond() callback. The example prompts mix
# English and Hinglish task-update messages matching the model's tuning data.
demo = gr.ChatInterface(
    fn=respond,
    title="TaskMind Interface",
    description="Chat with the TaskMind LoRA model.",
    examples=[
        "Who are you?",
        "@Satyam fix the growstreams deck ASAP NO Delay",
        "done bhai, merged the PR",
        "login page 60% ho gaya",
        "getting 500 error on registration",
    ],
)
# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()