File size: 3,323 Bytes
31e91dc
07e183d
a2ebcc6
07e183d
31e91dc
07e183d
 
 
29684f5
 
31e91dc
29684f5
07e183d
31e91dc
07e183d
9a8da89
31e91dc
07e183d
a2ebcc6
9a8da89
07e183d
 
31e91dc
9a8da89
07e183d
29684f5
 
31e91dc
9a8da89
07e183d
 
 
 
 
29684f5
07e183d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29684f5
 
07e183d
29684f5
 
 
 
31e91dc
 
07e183d
31e91dc
29684f5
9a8da89
07e183d
 
9a8da89
31e91dc
 
07e183d
 
29684f5
31e91dc
07e183d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a8da89
 
07e183d
9a8da89
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr

# --------------------
# Model setup
# --------------------
# Hub checkpoint for the base model, and the LoRA adapter applied on top of it.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_REPO = "nitya001/autotrain-4n1y9-5ekvs"

# System instruction placed at the start of every conversation prompt.
SYSTEM_PROMPT = (
    "You are a helpful assistant fine-tuned for loan journeys and UTR queries. "
    "Answer clearly and concisely. If you don't know some specific account value, "
    "explain what information is needed instead of hallucinating numbers."
)

# Choose the runtime target once: float16 weights on GPU, float32 on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print("Loading base model...")
# On GPU, weight placement is delegated to device_map="auto"; on CPU no map is passed.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map=("auto" if device == "cuda" else None),
    torch_dtype=dtype,
)

print("Loading LoRA adapter:", LORA_REPO)
model = PeftModel.from_pretrained(base_model, LORA_REPO)
model.to(device)
# Inference only: put modules such as dropout into evaluation mode.
model.eval()

# --------------------
# Generation function
# --------------------
# Token window shared by the prompt and the generated reply. 2048 was the
# previous prompt cap; NOTE(review): presumably the model's context length —
# confirm against model.config.max_position_embeddings.
_CONTEXT_TOKENS = 2048
# Maximum number of tokens generate() may append to the prompt.
_MAX_NEW_TOKENS = 256


def _content_to_text(content) -> str:
    """Return the plain text held in a chat message's ``content`` value.

    Strings pass through unchanged and ``None`` becomes ``""``. A list of
    content parts (plain strings, or dicts carrying a ``"text"`` key) has its
    text parts joined with newlines. This is a defensive path in case the
    installed Gradio sends structured content; confirm against the Gradio
    version in use. Any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "\n".join(texts)
    return str(content)


def _history_to_turns(history) -> list:
    """Normalize Gradio chat history into a list of ``(role, text)`` tuples.

    Two history shapes are accepted:
      * message dicts ``{"role": ..., "content": ...}`` (Gradio "messages"
        format); a missing role counts as "user", and roles other than
        "user"/"assistant" are skipped;
      * ``[user_text, assistant_text]`` pairs (Gradio's legacy "tuples"
        format); a ``None`` side of a pair (e.g. a pending reply) is skipped.
    ``None`` or an empty history yields an empty list.
    """
    turns = []
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role", "user")
            if role in ("user", "assistant"):
                turns.append((role, _content_to_text(item.get("content", ""))))
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            user_text, bot_text = item
            if user_text is not None:
                turns.append(("user", _content_to_text(user_text)))
            if bot_text is not None:
                turns.append(("assistant", _content_to_text(bot_text)))
    return turns


def _build_prompt(message: str, turns) -> str:
    """Render system prompt, prior turns and the new message as tagged text.

    The prompt ends with an open ``<|assistant|>`` tag so the model continues
    as the assistant.
    """
    # NOTE(review): this keeps the original tag layout (no newline after role
    # tags). Compare with tokenizer.apply_chat_template output and with the
    # format the LoRA adapter was trained on.
    parts = [f"<|system|>{SYSTEM_PROMPT}</s>\n"]
    parts.extend(f"<|{role}|>{text}</s>\n" for role, text in turns)
    parts.append(f"<|user|>{message}</s>\n<|assistant|>")
    return "".join(parts)


def generate_reply(message: str, history: list):
    """Generate the assistant's reply to ``message`` given the chat so far.

    ChatInterface passes:
      message: the latest user message (string)
      history: prior turns, either as message dicts
               ``[{"role": "user"/"assistant", "content": "..."}, ...]``
               or as legacy ``[user, assistant]`` pairs, depending on the
               Gradio version / ChatInterface ``type`` setting.

    Returns only the assistant's reply text; ChatInterface renders it in the UI.
    If decoding yields an empty string, a fixed "please rephrase" message is
    returned instead.
    """
    turns = _history_to_turns(history)
    prompt_budget = _CONTEXT_TOKENS - _MAX_NEW_TOKENS

    # Drop the oldest turns until the prompt fits the budget. Right-side
    # truncation would instead cut off the newest message and the trailing
    # "<|assistant|>" tag, which are exactly what generation depends on.
    while True:
        input_ids = tokenizer(
            _build_prompt(message, turns), return_tensors="pt"
        )["input_ids"]
        if input_ids.shape[-1] <= prompt_budget or not turns:
            break
        turns = turns[1:]

    if input_ids.shape[-1] > prompt_budget:
        # The new message alone exceeds the budget: keep the tail so the prompt
        # still ends with the "<|assistant|>" generation tag.
        input_ids = input_ids[:, -prompt_budget:]

    input_ids = input_ids.to(device)
    # Single unpadded sequence, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=_MAX_NEW_TOKENS,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (generate() echoes the prompt first).
    generated_ids = output_ids[0][input_ids.shape[-1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    if not answer:
        answer = "I'm not sure how to answer that. Could you rephrase your question?"

    return answer

# --------------------
# Gradio UI
# --------------------
# Chat UI wired to generate_reply, with a short usage hint and clickable examples.
demo = gr.ChatInterface(
    generate_reply,
    examples=[
        "What is my latest UTR?",
        "Explain my repayment schedule.",
        "How are late payment charges calculated?",
    ],
    description=(
        "Ask things like:\n"
        "- What is my latest UTR?\n"
        "- How is my EMI calculated?\n"
        "- Summarize my repayment schedule.\n"
    ),
    title="UTR & Loan Assistant (TinyLlama LoRA)",
)

if __name__ == "__main__":
    # Start the web UI when this file is executed as a script (e.g. `python app.py`).
    demo.launch()