janajankovic committed on
Commit
f500641
·
verified ·
1 Parent(s): 0dc282e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -28
app.py CHANGED
@@ -1,48 +1,174 @@
 
 
1
  import gradio as gr
2
- from transformers import AutoTokenizer, pipeline
3
- from peft import AutoPeftModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # HF repo of your LoRA-finetuned model (the one AutoTrain pushed)
6
- FINETUNED_MODEL_ID = "janajankovic/autotrain-juhh6-uwiv9" # <<< CHANGE THIS TO YOUR REPO ID
7
 
 
 
8
 
9
- # Load base+LoRA via PEFT
10
- model = AutoPeftModelForCausalLM.from_pretrained(FINETUNED_MODEL_ID)
11
- base_model_id = model.config.base_model_name_or_path
12
 
13
- # Use tokenizer from the base model (GaMS-1B-Chat)
14
- tokenizer = AutoTokenizer.from_pretrained(base_model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Text generation pipeline
17
- text_gen = pipeline(
 
 
 
 
 
 
 
18
  "text-generation",
19
  model=model,
20
  tokenizer=tokenizer,
21
- max_new_tokens=256,
22
- do_sample=True,
23
- temperature=0.7,
24
- top_p=0.9,
25
  )
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def respond(message, history):
29
- # message: current user message (string)
30
- # history: list of [user, assistant] pairs (ignored here, minimal chat)
31
- prompt = message
32
- outputs = text_gen(prompt, num_return_sequences=1)
33
- text = outputs[0]["generated_text"]
34
 
35
- # Many causal LM heads echo the prompt; strip it out if present
36
- if text.startswith(prompt):
37
- text = text[len(prompt):].lstrip()
38
 
39
- # ChatInterface expects a plain string here
40
- return text
41
 
 
 
 
42
 
43
  demo = gr.ChatInterface(
44
- fn=respond,
45
- title="GenUI – Slovene fine-tuned chat",
 
 
 
 
46
  )
47
 
48
  if __name__ == "__main__":
 
1
+ import os
2
+
3
  import gradio as gr
4
+ import pandas as pd
5
+ from sklearn.feature_extraction.text import TfidfVectorizer
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
+ from peft import PeftModel
10
+
11
+ # -------------------------------------------------------------------
12
+ # CONFIG
13
+ # -------------------------------------------------------------------
14
+
15
# --- Model identifiers ---------------------------------------------
# Repo id of the fine-tuned adapter that is attached to the base model
# (edit if the adapter lives elsewhere).
MODEL_ID = "janajankovic/autotrain-juhh6-uwiv9"
# Repo id of the base checkpoint; the tokenizer is loaded from it as well
# (edit if a different base was used for fine-tuning).
BASE_MODEL_ID = "cjvt/GaMS-1B-Chat"

# --- Retrieval -----------------------------------------------------
# Relative path of the CSV file that holds the text chunks.
CSV_PATH = "chunks_for_autotrain.csv"
# Number of additional chunks requested on top of the single best match.
N_NEIGHBORS = 4

# --- Sampling parameters passed to every generation call -----------
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.9
30
+
31
+ # -------------------------------------------------------------------
32
+ # LOAD MODEL (BASE + PEFT ADAPTER)
33
+ # -------------------------------------------------------------------
34
+
35
# The tokenizer and the weights both come from the base checkpoint; the
# adapter repo is only used for the PEFT layers attached below.
print("Loading base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype="auto")

# Wrap the base model with the fine-tuned adapter.
print("Loading PEFT adapter...")
model = PeftModel.from_pretrained(base_model, MODEL_ID)

# When the config defines no padding token, reuse the end-of-sequence token
# (only possible if the config defines one).
if model.config.pad_token_id is None:
    if model.config.eos_token_id is not None:
        model.config.pad_token_id = model.config.eos_token_id

# No sampling settings are fixed here; they are supplied on each call.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
56
 
57
+ # -------------------------------------------------------------------
58
+ # LOAD CHUNKS + BUILD TF-IDF RETRIEVER
59
+ # -------------------------------------------------------------------
60
+
61
# Read the chunk table. Missing "text" cells become empty strings, so row i
# of the frame always corresponds to documents[i].
print("Loading CSV chunks...")
df = pd.read_csv(CSV_PATH)
df = df.assign(text=df["text"].fillna(""))
documents = df["text"].tolist()

# Fit the TF-IDF model once at start-up; the vocabulary is capped at
# 50,000 features.
print("Building TF-IDF index...")
vectorizer = TfidfVectorizer(max_features=50_000)
doc_matrix = vectorizer.fit_transform(documents)
70
+
71
+ # -------------------------------------------------------------------
72
+ # RETRIEVAL: TOP-1 + NEXT N_NEIGHBORS MOST SIMILAR CHUNKS
73
+ # -------------------------------------------------------------------
74
+
75
def retrieve_chunks(query: str, n_neighbors: int = N_NEIGHBORS):
    """Return the chunks most similar to ``query``, best match first.

    Similarity is the cosine between the TF-IDF vector of the query and the
    rows of the module-level ``doc_matrix``.

    Args:
        query: The user question; leading and trailing whitespace is ignored.
        n_neighbors: Maximum number of extra chunks returned after the best
            match. Values below zero behave like zero.

    Returns:
        A list of at most ``n_neighbors + 1`` chunk texts in order of
        decreasing similarity. The list is empty when the query is blank or
        when no chunk has a positive similarity to it.
    """
    query = query.strip()
    if not query:
        return []

    # Similarity of the question against every chunk.
    q_vec = vectorizer.transform([query])
    sims = cosine_similarity(q_vec, doc_matrix).flatten()

    if sims.max() <= 0:
        return []

    # Best match plus up to n_neighbors runners-up.
    limit = max(n_neighbors, 0) + 1
    ranked = sims.argsort()[::-1][:limit]

    # A similarity of zero is treated as "no match" (same rule as the early
    # return above), so such chunks are not used to pad the result.
    return [documents[int(idx)] for idx in ranked if sims[idx] > 0]
103
+
104
def build_context(question: str) -> str:
    """Format the chunks retrieved for ``question`` as one context string.

    Every chunk is preceded by a 1-based ``[CHUNK i]`` label on its own line,
    and labelled chunks are separated by a blank line. The result is ``""``
    when retrieval returns nothing.
    """
    retrieved = retrieve_chunks(question, N_NEIGHBORS)
    # Joining an empty sequence yields "", which covers the no-result case.
    return "\n\n".join(
        f"[CHUNK {position}]\n{text}"
        for position, text in enumerate(retrieved, start=1)
    )
114
+
115
+ # -------------------------------------------------------------------
116
+ # CHAT FUNCTION
117
+ # -------------------------------------------------------------------
118
+
119
# Instruction block placed at the top of every prompt (Slovene): answer in
# Slovene, use the context when relevant, otherwise answer from general
# knowledge and say so.
SYSTEM_PROMPT = (
    "Ti si pomočnik, ki odgovarja v slovenščini.\n"
    "Uporabi spodnji kontekst, če je relevanten. "
    "Če kontekst ne vsebuje odgovora, odgovori po svojih najboljših močeh "
    "in jasno povej, da se opiraš na splošno znanje.\n"
)


def generate_answer(message: str) -> str:
    """Generate a reply to ``message``, grounded on retrieved chunks if any.

    The prompt consists of ``SYSTEM_PROMPT``, an optional ``Kontekst:``
    section built by ``build_context``, the user question and an answer cue.

    Args:
        message: The user question, inserted into the prompt as-is.

    Returns:
        The generated completion with surrounding whitespace removed; the
        prompt itself is not part of the returned text.
    """
    context = build_context(message)

    # One template for both cases: the context section is simply left out
    # when retrieval found nothing.
    context_section = f"Kontekst:\n{context}\n\n" if context else ""
    full_prompt = (
        f"{SYSTEM_PROMPT}\n"
        f"{context_section}"
        f"Vprašanje uporabnika:\n{message}\n\n"
        "Odgovor (v slovenščini):\n"
    )
    # NOTE(review): the prompt length is not capped; confirm that five chunks
    # plus the question fit the model's context window.

    outputs = generator(
        full_prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        pad_token_id=model.config.pad_token_id,
        # Let the pipeline return only the newly generated text, so the
        # prompt does not have to be sliced off the output by hand.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
157
 
158
def chat_fn(message, history):
    """Chat callback: answer ``message`` on its own.

    ``history`` is part of the callback signature but is not read, so each
    turn is answered without the earlier conversation.
    """
    reply = generate_answer(message)
    return reply
160
 
161
+ # -------------------------------------------------------------------
162
+ # GRADIO UI
163
+ # -------------------------------------------------------------------
164
 
165
# Chat front-end: every user message is answered by chat_fn.
demo = gr.ChatInterface(
    chat_fn,
    description=(
        "Klepet z lastnim fine-tunanim modelom.\n"
        "Model samodejno poišče najbližje besedilne 'chunke' v CSV in jih uporabi kot kontekst."
    ),
    title="Gen-UI fine-tuned Slovene model",
)
173
 
174
  if __name__ == "__main__":