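"""GenUI – a Slovenian teacher-assistant chatbot for Hugging Face Spaces.

Pipeline: TF-IDF retrieval over CSV text chunks -> prompt assembly ->
generation with a LoRA-finetuned GaMS-1B-Chat -> Gradio ChatInterface.
"""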
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# ------------------------------------------------------------------
# CONFIG – edit these defaults (or set the env vars) to match your repos
# ------------------------------------------------------------------
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "cjvt/GaMS-1B-Chat")
ADAPTER_ID = os.getenv("ADAPTER_ID", "janajankovic/autotrain-juhh6-uwiv9")
CSV_PATH = "chunks_for_autotrain.csv"
TOP_K = 4
MAX_INPUT_LEN = 2048
MAX_NEW_TOKENS = 256

# Enforce non-empty answers
MIN_NEW_TOKENS = 32  # prevent immediate EOS / 1-4 word outputs
MIN_CHARS = 60       # require roughly one sentence's worth of text
MAX_RETRIES = 2
# ------------------------------------------------------------------
# LOAD CSV CHUNKS + TF-IDF INDEX
# ------------------------------------------------------------------
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(f"CSV file not found: {CSV_PATH}")

df = pd.read_csv(CSV_PATH)
if "chunk" in df.columns:
    text_col = "chunk"
elif "text" in df.columns:
    text_col = "text"
else:
    text_col = df.columns[0]

chunks = df[text_col].astype(str).tolist()
if len(chunks) == 0:
    raise ValueError("No chunks loaded from CSV – check the file content.")

vectorizer = TfidfVectorizer(max_features=4096)
tfidf_matrix = vectorizer.fit_transform(chunks)
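# NOTE: TF-IDF + cosine similarity is a deliberately lightweight retriever:
# it needs no GPU and no extra model, at the cost of matching on surface
# word overlap rather than meaning. Swapping in sentence embeddings would
# be a drop-in upgrade if retrieval quality becomes the bottleneck.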
# ------------------------------------------------------------------
# LOAD MODEL + TOKENIZER (BASE + LoRA ADAPTER)
# ------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# CRITICAL: if the prompt is too long, keep the END (question + "Odgovor:")
tokenizer.truncation_side = "left"
tokenizer.padding_side = "left"

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model = model.merge_and_unload()
model.to(device)
model.eval()
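# merge_and_unload() folds the LoRA weights into the base model's matrices,
# so inference runs on a single plain model with no adapter overhead. The
# trade-off: the adapter can no longer be toggled off at runtime.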
# ------------------------------------------------------------------
# PROMPT + RETRIEVAL
# ------------------------------------------------------------------
SYSTEM_PROMPT = (
    "Ti si pomočnik za učitelje in odgovarjaš v slovenščini. "
    "Odgovarjaj kratko, jasno in brez ponavljanja istih fraz. "
    "Če v podanih odlomkih ni odgovora, to jasno povej."
)
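# English translation of SYSTEM_PROMPT, for reference:
# "You are an assistant for teachers and you answer in Slovenian.
#  Answer briefly, clearly, and without repeating the same phrases.
#  If the provided excerpts do not contain the answer, say so clearly."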
def retrieve_chunks(question: str, top_k: int = TOP_K):
    q_vec = vectorizer.transform([question])
    sims = cosine_similarity(q_vec, tfidf_matrix)[0]
    top_idx = sims.argsort()[::-1][:top_k]
    return [chunks[i] for i in top_idx]
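# Example (hypothetical question, for illustration only):
#   retrieve_chunks("Kako poteka ocenjevanje znanja?")  # "How is assessment done?"
# returns the TOP_K chunks whose TF-IDF vectors are closest to the question.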
def build_prompt(question: str, retrieved):
    context = "\n\n---\n\n".join(retrieved)
    prompt = (
        f"{SYSTEM_PROMPT}\n\n"
        f"Kontekst:\n{context}\n\n"
        "Navodilo:\n"
        "Na podlagi konteksta odgovori na vprašanje NA KRATKO (3–6 stavkov). "
        "Ne ponavljaj istih besed ali stavkov.\n"
        f"Vprašanje: {question}\n\n"
        "Odgovor:"
    )
    return prompt
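# Prompt fields in English: "Kontekst" = context, "Navodilo" = instruction
# ("Based on the context, answer the question BRIEFLY (3–6 sentences). Do not
# repeat the same words or sentences."), "Vprašanje" = question, "Odgovor" = answer.
# Ending the prompt with "Odgovor:" cues the model to start generating the answer.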
# ------------------------------------------------------------------
# GENERATION FUNCTION FOR CHAT
# ------------------------------------------------------------------
def generate_answer(message: str, history):
    # history is supplied by gr.ChatInterface but intentionally unused:
    # each turn is answered independently from freshly retrieved chunks.
    retrieved = retrieve_chunks(message, top_k=TOP_K)
    prompt = build_prompt(message, retrieved)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LEN,
    ).to(device)

    def _generate_once(gen_kwargs: dict) -> str:
        with torch.no_grad():
            out = model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens, not the echoed prompt.
        gen_ids = out[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
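    # Sampling configuration: moderate temperature/top_p for natural text,
    # plus repetition_penalty and no_repeat_ngram_size to suppress the
    # looping that small models are prone to. These values are starting
    # points, not tuned optima; adjust per model and data.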
    base_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.15,
        no_repeat_ngram_size=4,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Try to enforce a minimum generation length (prevents 1–4 word answers).
    try_kwargs = dict(base_kwargs)
    try_kwargs["min_new_tokens"] = MIN_NEW_TOKENS
    raw_text = ""
    for _ in range(MAX_RETRIES + 1):
        try:
            raw_text = _generate_once(try_kwargs)
        except TypeError:
            # Older transformers: min_new_tokens not supported
            raw_text = _generate_once(base_kwargs)

        # Cleanup: keep at most two consecutive copies of any identical line
        lines = [l.strip() for l in raw_text.splitlines() if l.strip()]
        cleaned = []
        last_line = None
        rep = 0
        for l in lines:
            if l == last_line:
                rep += 1
                if rep >= 2:
                    continue
            else:
                rep = 0
            last_line = l
            cleaned.append(l)
        answer = " ".join(cleaned).strip() or raw_text.strip()

        # Accept the answer if it looks like at least one full sentence
        if len(answer) >= MIN_CHARS and any(p in answer for p in ".!?"):
            return answer

        # Retry: loosen constraints a bit to avoid early stops / dead outputs
        try_kwargs["temperature"] = min(0.95, try_kwargs.get("temperature", 0.7) + 0.15)
        try_kwargs["top_p"] = min(0.98, try_kwargs.get("top_p", 0.9) + 0.05)
        try_kwargs["repetition_penalty"] = max(1.05, try_kwargs.get("repetition_penalty", 1.15) - 0.05)
        try_kwargs["no_repeat_ngram_size"] = max(2, try_kwargs.get("no_repeat_ngram_size", 4) - 1)

    # Hard fallback, guaranteeing at least one full sentence. In English:
    # "The provided excerpts do not contain enough information to answer
    # this question reliably."
    return "V podanih odlomkih ni dovolj informacij za zanesljiv odgovor na to vprašanje."
# ------------------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------------------
demo = gr.ChatInterface(
    fn=generate_answer,
    title="GenUI – učiteljski pomočnik",  # "GenUI – teacher's assistant"
    description="Klepetalnik, prilagojen na tvoje gradivo (CSV chunki).",  # "A chatbot adapted to your material (CSV chunks)."
)

if __name__ == "__main__":
    demo.launch()
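# On Hugging Face Spaces the app is served automatically; locally,
# `python app.py` starts the Gradio server (by default at http://localhost:7860).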