Spaces: Running on Zero
| import os | |
| import faiss | |
| import pickle | |
| import gradio as gr | |
| import spaces | |
| import uvicorn | |
| import threading | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import hf_hub_download | |
| from openai import OpenAI | |
# ===============================
# CONFIG
# ===============================
# Hugging Face access token from the environment; None if unset (the
# router will then reject requests — set HF_TOKEN in the Space secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")

# OpenAI-compatible Hugging Face client: the HF router exposes an
# OpenAI-style /v1 chat-completions API, so the openai SDK can be reused.
client = OpenAI(
    base_url="https://router.huggingface.co/v1",
    api_key=HF_TOKEN
)

# Model id routed through the HF router; the ":fastest" suffix appears to
# select an inference-provider policy — TODO confirm against router docs.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct:fastest"
# ===============================
# LOAD EMBEDDINGS
# ===============================
print("Loading embedding model...")
# Sentence-transformers model used to embed user queries for FAISS search.
# Loaded once at import time; downloads weights on first run.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# ===============================
# LOAD RAG DATA
# ===============================
# Fetch the prebuilt FAISS index and the parallel document list from the
# model repo. `documents[i]` is the text behind index vector i.
print("Downloading FAISS index...")
index_path = hf_hub_download(
    repo_id="mahmodGendy/startup-llama-model",
    filename="faiss.index"
)
index = faiss.read_index(index_path)

print("Downloading documents...")
docs_path = hf_hub_download(
    repo_id="mahmodGendy/startup-llama-model",
    filename="docs.pkl"
)
# Fix: use a context manager so the file handle is closed deterministically
# (the original `pickle.load(open(...))` left closing to the GC).
# NOTE(review): pickle.load executes arbitrary code on malicious input —
# acceptable only because this artifact comes from our own repo.
with open(docs_path, "rb") as f:
    documents = pickle.load(f)
print("RAG system ready.")
# ===============================
# RAG RETRIEVAL
# ===============================
def retrieve_context(query, top_k=5):
    """Return the `top_k` documents most similar to *query*, newline-joined.

    Embeds the query with the module-level sentence-transformer, searches
    the FAISS index, and maps the resulting row indices back to the
    in-memory document list.
    """
    query_vec = embedding_model.encode([query])
    _distances, hit_ids = index.search(query_vec, top_k)
    matched = (documents[doc_id] for doc_id in hit_ids[0])
    return "\n".join(matched)
# ===============================
# GPU / Hosted Inference
# ===============================
def ask_llama(user_input):
    """Answer *user_input* using RAG context and the hosted Llama model.

    Retrieves supporting documents, chooses a structured or conversational
    response style via a keyword heuristic, then calls the HF router's
    OpenAI-compatible chat-completions endpoint. Returns the reply text.
    """
    context = retrieve_context(user_input)

    # Heuristic: treat the message as a startup-evaluation request when it
    # contains any of these trigger words (case-insensitive substring match).
    triggers = ("idea", "start", "business", "startup", "viable", "launch")
    lowered = user_input.lower()
    wants_evaluation = any(word in lowered for word in triggers)

    if wants_evaluation:
        response_style = """
1. Problem Validation
2. Market Evaluation
3. Risks
4. Improvement Suggestions
"""
    else:
        response_style = "Respond naturally and conversationally."

    system_prompt = f"""
You are a startup validation expert.
Language Rule:
- English → English
- MSA Arabic → MSA Arabic
- Egyptian dialect → Egyptian Arabic
Context:
{context}
{response_style}
"""

    # Use hosted API: Hugging Face OpenAI-compatible
    completion = client.chat.completions.create(
        model=MODEL_ID,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ],
        max_tokens=400,
        temperature=0.7,
        top_p=0.9,
    )
    return completion.choices[0].message.content
# ===============================
# FASTAPI
# ===============================
app = FastAPI()

class Query(BaseModel):
    # Request schema for the HTTP API: the user's question text.
    question: str
@app.post("/ask")
def ask(query: Query):
    """POST /ask — run the RAG + LLM pipeline on the submitted question.

    Bug fix: the original function was never registered with FastAPI — it
    had no route decorator, so uvicorn served an app with no endpoints and
    the API was unreachable. `@app.post("/ask")` registers it.
    """
    answer = ask_llama(query.question)
    return {"answer": answer}
# ===============================
# GRADIO (Required for ZeroGPU)
# ===============================
def gradio_wrapper(question):
    """Gradio adapter: forward the textbox input to the RAG/LLM pipeline."""
    answer = ask_llama(question)
    return answer
# Minimal single-textbox UI; a Gradio interface must be running for the
# Space's ZeroGPU detection to work.
demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=gr.Textbox(label="Ask your startup question"),
    outputs=gr.Textbox(label="Response")
)
# ===============================
# START SERVERS
# ===============================
def run_fastapi():
    # Blocking call: serves the FastAPI app on port 8000 until shutdown.
    uvicorn.run(app, host="0.0.0.0", port=8000)
if __name__ == "__main__":
    # Start FastAPI in a background thread. Fix: mark it daemon so the
    # non-daemon uvicorn thread cannot keep the process alive (hanging
    # shutdown) after Gradio exits in the main thread.
    api_thread = threading.Thread(target=run_fastapi, daemon=True)
    api_thread.start()
    # Start Gradio in the main thread (required for ZeroGPU detection).
    demo.launch(server_name="0.0.0.0", server_port=7860)