# gradprojectLLM / app.py
# Author: mahmodGendy — "Update app.py" (commit c6ca0c6, verified)
import os
import faiss
import pickle
import gradio as gr
import spaces
import uvicorn
import threading
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from openai import OpenAI
# ===============================
# CONFIG
# ===============================
# Hugging Face access token; None when the env var is unset — the client is
# still constructed, but the first authenticated request will then fail.
HF_TOKEN = os.environ.get("HF_TOKEN")
# OpenAI-compatible Hugging Face client (HF router speaks the OpenAI API)
client = OpenAI(
    base_url="https://router.huggingface.co/v1",
    api_key=HF_TOKEN
)
# Hosted chat model; the ":fastest" suffix presumably selects the fastest
# router provider — TODO confirm against HF router docs.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct:fastest"
# ===============================
# LOAD EMBEDDINGS
# ===============================
print("Loading embedding model...")
# Embeds queries into the same vector space the FAISS index was built with.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ===============================
# LOAD RAG DATA
# ===============================
print("Downloading FAISS index...")
index_path = hf_hub_download(
    repo_id="mahmodGendy/startup-llama-model",
    filename="faiss.index"
)
index = faiss.read_index(index_path)

print("Downloading documents...")
docs_path = hf_hub_download(
    repo_id="mahmodGendy/startup-llama-model",
    filename="docs.pkl"
)
# Context manager closes the handle deterministically — the original
# pickle.load(open(docs_path, "rb")) leaked the open file object.
# NOTE(review): pickle executes arbitrary code on load; this is acceptable
# only because the repo above is trusted/owned by the author.
with open(docs_path, "rb") as f:
    documents = pickle.load(f)
print("RAG system ready.")
# ===============================
# RAG RETRIEVAL
# ===============================
def retrieve_context(query, top_k=5):
    """Return the top_k most similar document chunks, newline-joined.

    Args:
        query: Natural-language question to embed and search for.
        top_k: Number of nearest neighbours to request from FAISS.

    Returns:
        A single string of retrieved chunks separated by newlines
        (empty string when nothing is retrieved).
    """
    query_embedding = embedding_model.encode([query])
    _, ids = index.search(query_embedding, top_k)
    # FAISS pads the id row with -1 when the index holds fewer than top_k
    # vectors; the original code then returned documents[-1] by accident.
    retrieved_docs = [documents[i] for i in ids[0] if i >= 0]
    return "\n".join(retrieved_docs)
# ===============================
# GPU / Hosted Inference
# ===============================
@spaces.GPU
def ask_llama(user_input):
    """Answer *user_input* with RAG context via the hosted Llama model.

    Retrieves supporting chunks from the FAISS index, picks a structured
    or conversational response style based on simple keyword matching,
    and calls the OpenAI-compatible HF router endpoint.
    """
    context = retrieve_context(user_input)

    # Substring keyword match decides between structured evaluation and chat.
    lowered = user_input.lower()
    keywords = ("idea", "start", "business", "startup", "viable", "launch")
    wants_evaluation = any(word in lowered for word in keywords)

    if wants_evaluation:
        response_style = """
1. Problem Validation
2. Market Evaluation
3. Risks
4. Improvement Suggestions
"""
    else:
        response_style = "Respond naturally and conversationally."

    system_prompt = f"""
You are a startup validation expert.
Language Rule:
- English → English
- MSA Arabic → MSA Arabic
- Egyptian dialect → Egyptian Arabic
Context:
{context}
{response_style}
"""

    # Hosted inference through the Hugging Face OpenAI-compatible router.
    completion = client.chat.completions.create(
        model=MODEL_ID,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ],
        max_tokens=400,
        temperature=0.7,
        top_p=0.9,
    )
    return completion.choices[0].message.content
# ===============================
# FASTAPI
# ===============================
app = FastAPI()


class Query(BaseModel):
    # Request body for POST /ask.
    question: str


@app.post("/ask")
def ask(query: Query):
    """Run the RAG + LLM pipeline for one question and return the answer."""
    return {"answer": ask_llama(query.question)}
# ===============================
# GRADIO (Required for ZeroGPU)
# ===============================
def gradio_wrapper(question):
    """Forward a Gradio textbox question to the LLM pipeline."""
    return ask_llama(question)


# Gradio UI — a Space must expose a Gradio app for ZeroGPU detection.
question_box = gr.Textbox(label="Ask your startup question")
answer_box = gr.Textbox(label="Response")
demo = gr.Interface(fn=gradio_wrapper, inputs=question_box, outputs=answer_box)
# ===============================
# START SERVERS
# ===============================
def run_fastapi():
    """Serve the FastAPI app on port 8000 (blocking; run in a thread)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    # daemon=True so the API thread cannot keep the process alive after the
    # Gradio server exits (the original non-daemon thread blocked shutdown).
    threading.Thread(target=run_fastapi, daemon=True).start()
    # Gradio runs in the main thread (required for ZeroGPU detection).
    demo.launch(server_name="0.0.0.0", server_port=7860)