Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import pickle | |
| import faiss | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
# --- Config ---
# On-disk cache files: the serialized FAISS index and the pickled per-comic
# metadata list that is kept in the same order as the index rows.
INDEX_FILE = "xkcd.index"
META_FILE = "meta.pkl"
# Chat model used to pick the best-fitting comic; override it with the
# CHAT_MODEL environment variable.
CHAT_MODEL = os.environ.get("CHAT_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
| # --- Build / load index --- | |
def build_index():
    """Build the FAISS index over the xkcd dataset and cache it on disk.

    Loads the ``olivierdehaene/xkcd`` train split, embeds one text per comic
    (title, transcript and explanation joined by single spaces) with
    ``all-MiniLM-L6-v2``, adds the vectors to a flat L2 index that is written
    to ``INDEX_FILE``, and pickles the per-comic metadata to ``META_FILE``.

    Returns:
        A ``(index, meta)`` tuple: the FAISS index and a list of dicts with
        the keys ``id``, ``title``, ``transcript`` and ``explanation``, built
        by iterating the dataset in the same order as the embedded texts.
    """
    print("Building FAISS index...")
    dataset = load_dataset("olivierdehaene/xkcd", split="train")
    encoder = SentenceTransformer("all-MiniLM-L6-v2")

    def as_text(row):
        # Missing or empty fields contribute an empty string to the text.
        title = row["title"] or ""
        transcript = row["transcript"] or ""
        explanation = row.get("explanation") or ""
        return f"{title} {transcript} {explanation}"

    corpus = [as_text(row) for row in dataset]
    vectors = encoder.encode(corpus, convert_to_numpy=True, show_progress_bar=True)

    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    faiss.write_index(flat_index, INDEX_FILE)

    # Store just the metadata we need (pickle-friendly)
    records = [
        {
            "id": row["id"],
            "title": row["title"],
            "transcript": row["transcript"],
            "explanation": row.get("explanation", ""),
        }
        for row in dataset
    ]
    with open(META_FILE, "wb") as fh:
        pickle.dump(records, fh)
    return flat_index, records
def get_index():
    """Return ``(index, meta)``, reusing the on-disk cache when it is complete.

    The cache is used only when both ``INDEX_FILE`` and ``META_FILE`` exist;
    otherwise everything is rebuilt (and re-cached) by ``build_index()``.
    """
    cache_ready = os.path.exists(INDEX_FILE) and os.path.exists(META_FILE)
    if not cache_ready:
        return build_index()

    print("Loading cached index...")
    with open(META_FILE, "rb") as fh:
        cached_index = faiss.read_index(INDEX_FILE)
        cached_meta = pickle.load(fh)
    return cached_index, cached_meta
def get_id_from_string(str: str) -> str:
    """Return the text inside the first ``[...]`` pair found in *str*.

    The closing bracket is searched for only after the opening one, so a
    stray ``]`` earlier in the text cannot produce an empty or wrong slice.
    Whitespace directly inside the brackets is stripped.

    Args:
        str: Text expected to contain a bracketed ID, for example
            ``"[123] https://xkcd.com/123/"``. The parameter name shadows the
            builtin; it is kept so existing keyword callers keep working.

    Returns:
        The content between the first ``[`` and the next ``]`` after it.

    Raises:
        ValueError: If *str* has no ``[``, or no ``]`` after that ``[``.
    """
    text = str
    id_start = text.index("[") + 1
    # Start the search at id_start so a "]" before the "[" is ignored.
    id_end = text.index("]", id_start)
    return text[id_start:id_end].strip()
| # --- Chat handler --- | |
def _extract_reply_text(resp) -> str:
    """Return the stripped assistant text from a chat-completion response.

    Tolerates slight schema differences: the message may be a dict or an
    object, and a plain ``text`` attribute on the choice is used as fallback.
    Returns an empty string when no text is present.
    """
    try:
        choice = resp.choices[0]
        msg = getattr(choice, "message", None)
        if isinstance(msg, dict):
            out = msg.get("content", "")
        else:
            out = getattr(msg, "content", "") or getattr(choice, "text", "")
    except Exception:
        # Best effort: an unexpected response shape is shown verbatim.
        out = str(resp)
    if not out:
        # Covers both "" and a None content field.
        return ""
    return str(out).strip()


def _fetch_comic_image_url(comic_id: str) -> str:
    """Return the image URL of comic *comic_id* from the xkcd JSON endpoint.

    Raises:
        ValueError: If *comic_id* is not a plain decimal number, or the
            response body is not valid JSON.
        KeyError: If the JSON document has no ``"img"`` entry.
        OSError: On network failures, HTTP errors and timeouts.
    """
    import json
    import urllib.request

    # The ID comes from model output; only let plain digits into the URL.
    if not (comic_id.isascii() and comic_id.isdigit()):
        raise ValueError(f"not a numeric xkcd ID: {comic_id!r}")
    info_url = f"https://xkcd.com/{comic_id}/info.0.json"
    with urllib.request.urlopen(info_url, timeout=10) as reply:
        return json.load(reply)["img"]


def respond(
    message: str,
    history: list[dict[str, str]],
    oauth: gr.OAuthToken | None = None,  # Gradio injects this when available
):
    """Chat handler: pick the xkcd comic that best fits *message*.

    Embeds the message, retrieves up to five nearest comics from the FAISS
    index, asks the chat model to choose one, and tries to attach the chosen
    comic's image.

    Args:
        message: The user's description of a situation.
        history: Previous chat messages; accepted for the chat interface but
            not used.
        oauth: Hugging Face OAuth token used as the API key for inference.

    Returns:
        A sign-in hint (str) when *oauth* is missing; otherwise the model's
        answer as a str, or ``[answer, gr.Image]`` when the comic ID could be
        parsed from the answer and its image URL fetched.
    """
    if not oauth:
        return "⚠️ Please sign in with your Hugging Face account (top of the page)"
    token = oauth.token

    # Embed the query and search FAISS
    query_vec = embedder.encode([message], convert_to_numpy=True)
    _distances, neighbours = index.search(query_vec, 5)
    # FAISS pads the result with -1 when fewer than k neighbours exist;
    # meta[-1] would silently select the last comic, so skip those slots.
    candidates = [meta[int(i)] for i in neighbours[0] if i >= 0]
    context = "\n".join(
        f"[{c['id']}] {c['title']}\nTranscript: {c['transcript']}\nExplanation: {c['explanation']}"
        for c in candidates
    )
    prompt = (
        f'Situation: "{message}"\n'
        "Here are candidate xkcd comics:\n"
        f"{context}\n"
        "Which comic fits best and why?\n"
        "Please answer with the comic ID, URL (https://xkcd.com/ID/) and a short explanation in the format:\n"
        "[ID] URL\n"
        "EXPLANATION\n"
    )
    print("[PROMPT] " + prompt)

    client = InferenceClient(model=CHAT_MODEL, api_key=token)
    resp = client.chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant that selects the most suitable xkcd comic."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=200,
        temperature=0.0,  # TODO
    )

    out_text = _extract_reply_text(resp)
    if not out_text:
        return "Sorry, I couldn't parse the model response."

    try:
        comic_id = get_id_from_string(out_text)
        print(f'Read ID: {comic_id}')
        img_url = _fetch_comic_image_url(comic_id)
        print(f'Got image url: {img_url}')
    except (ValueError, KeyError, OSError) as err:
        # The image is a bonus: on a bad ID or a failed lookup, fall back to
        # the text answer instead of failing the whole chat turn.
        print("Couldn't parse xkcd ID or get image! That should not happen.", repr(err))
        return out_text
    return [out_text, gr.Image(value=img_url)]
if __name__ == "__main__":
    # --- UI ---
    with gr.Blocks(theme='gstaff/xkcd') as demo:
        gr.Markdown("# xkcd Comic Finder")
        gr.Markdown(
            "Sign in with your Hugging Face account so the app can call the model via the Inference API."
            "\n\n> If you deploy to a Space, add `hf_oauth: true` in your Space metadata and grant the `inference:api` scope."
        )
        gr.LoginButton()  # Shows “Sign in with Hugging Face”
        gr.ChatInterface(
            fn=respond,
            title="xkcd Comic Finder",
            description="Find the most suitable xkcd comic for your situation. Use the login button above.",
            examples=[
                "I need a comic about procrastination.",
                "A comic for programmers debugging code.",
                "Life advice in comic form.",
            ],
            type="messages",
        )

    # respond() reads these three names as module globals. This code already
    # runs at module scope, so plain assignment creates them; a `global`
    # statement here would have no effect.
    index, meta = get_index()
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    demo.launch()