Spaces:
Paused
Paused
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# Hugging Face Hub repository IDs: the base instruct model, and the PEFT adapter
# that is applied on top of it (its repo also supplies the tokenizer).
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
ADAPTER = "iamcodio/codio-rogerian-v1"

# Filled in lazily by load_model() on first use; both stay None until then.
model = tokenizer = None
def load_model():
    """Lazily load the tokenizer and the adapter-wrapped model into module globals.

    Returns immediately once ``model`` is populated. Otherwise:

    1. Read the tokenizer from the adapter repository.
    2. Load the base model weights in float16 with ``device_map="auto"``.
    3. Wrap them with the PEFT adapter and switch the result to eval mode.

    NOTE(review): ``spaces`` is imported by this file but ``@spaces.GPU`` is
    never applied; if this Space runs on ZeroGPU hardware, confirm the model
    actually ends up on a GPU.
    """
    global model, tokenizer
    if model is not None:
        return

    tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base_model, ADAPTER)
    model.eval()
def _content_to_text(content):
    """Flatten a chat message ``content`` value into plain text.

    Handles these shapes:
    - ``None``: returns ``""``.
    - A plain string: returned unchanged.
    - A dict: its ``"text"`` value when that is a string, else ``""``.
    - A list/tuple of such blocks, e.g. ``{"type": "text", "text": ...}``:
      the non-empty text parts joined with single spaces.

    Any other value is converted with ``str()``. Flattening keeps structured
    content from being serialised verbatim into the prompt, where the model
    would otherwise learn to echo that structure back.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        parts = (_content_to_text(block) for block in content)
        return " ".join(part for part in parts if part)
    return str(content)


def _history_to_messages(history):
    """Convert Gradio chat history into role/content dicts with string content.

    Supports two formats:
    - "messages": dicts with ``role`` and ``content``.
    - Legacy "tuples": ``[user, assistant]`` pairs.

    Turns whose text is empty are skipped, so ``None`` placeholders never
    reach the chat template.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            text = _content_to_text(turn.get("content"))
            if text:
                messages.append({"role": turn["role"], "content": text})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text = _content_to_text(turn[0])
            assistant_text = _content_to_text(turn[1])
            if user_text:
                messages.append({"role": "user", "content": user_text})
            if assistant_text:
                messages.append({"role": "assistant", "content": assistant_text})
    return messages


def _strip_structured_output(text):
    """Unwrap a reply the model emitted as a serialised list of content blocks.

    If ``text`` starts with ``[{`` and parses as JSON, or failing that as a
    Python literal (single quotes), into a list, return the string ``"text"``
    fields of its dict items joined with spaces.

    In every other case ``text`` is returned unchanged, including when nothing
    textual can be extracted, so a reply is never blanked out.
    """
    import ast
    import json

    if not text.startswith("[{"):
        return text
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        try:
            # literal_eval only evaluates literals, so it is safe on model output.
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return text
    if not isinstance(parsed, list):
        return text
    parts = [
        item["text"]
        for item in parsed
        if isinstance(item, dict) and isinstance(item.get("text"), str)
    ]
    return " ".join(parts) if parts else text


def respond(message, history):
    """Generate one assistant reply for a Gradio ChatInterface turn.

    Args:
        message: The user's new message. Usually a string; structured content
            is flattened to text.
        history: Prior turns, either in Gradio "messages" format (role/content
            dicts) or legacy "tuples" format (``[user, assistant]`` pairs).

    Returns:
        The decoded reply text with surrounding whitespace removed.
    """
    load_model()
    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": _content_to_text(message)})

    tokenized = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = tokenized["input_ids"].to(model.device)
    attention_mask = tokenized["attention_mask"].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    reply = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    # Strip first so leading whitespace cannot hide a "[{" prefix.
    return _strip_structured_output(reply.strip()).strip()
# Chat UI wired to respond(); strings are shown verbatim in the Space.
demo = gr.ChatInterface(
    fn=respond,
    title="BrainFart — Rogerian Listening Model",
    description="Fine-tuned Llama 3.2 3B on therapeutic conversation data. Not a therapist — just here to listen.",
    examples=[
        "I've been feeling really overwhelmed lately and I don't know why",
        "I think I'm stuck in a cycle of self-sabotage",
        "Everyone keeps telling me to be positive but it's not that simple",
    ],
)

# Start the server only when executed as a script, so importing this module
# (e.g. for testing respond) does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()