Shreekant Kalwar (Nokia) committed
Commit 40fce64 · 1 Parent(s): 4395cc5

Gemini Try

Files changed (2)
  1. app.py  +14 -33
  2. app2.py +57 -0
app.py CHANGED
@@ -1,45 +1,31 @@
  from fastapi import FastAPI
  from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from fastapi.middleware.cors import CORSMiddleware
- import torch
+ import google.generativeai as genai
  import os
+ from dotenv import load_dotenv

- # Ensure Hugging Face cache uses a writable path
- os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"
- os.environ["HF_HOME"] = "/app/.cache"
+ # Load variables from .env file
+ load_dotenv()
+ # ✅ Configure API Key (set GOOGLE_API_KEY in environment variables)
+ genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

  app = FastAPI()

  # ✅ Allow all origins
  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"],  # allow all origins
+     allow_origins=["*"],
      allow_credentials=True,
-     allow_methods=["*"],  # allow all HTTP methods
-     allow_headers=["*"],  # allow all headers
+     allow_methods=["*"],
+     allow_headers=["*"],
  )

-
  class ChatRequest(BaseModel):
      message: str

- # Load DeepSeek model (small one for local use)
- model_name = "deepseek-ai/deepseek-coder-1.3b-base"
-
- # model_name = "deepseek-ai/deepseek-llm-7b-base"
-
- #model_name="Qwen/Qwen2.5-1.5B-Instruct"
- #model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
- print("Loading model... this may take a minute ⏳")
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     device_map="auto"
- )
- print("Model loaded ✅")
+ # ✅ Load Gemini model (example: gemini-1.5-flash is lightweight & fast)
+ model = genai.GenerativeModel("gemini-1.5-flash")

  @app.get("/")
  def root():
@@ -47,11 +33,6 @@ def root():

  @app.post("/chat")
  def chat(request: ChatRequest):
-     """Chat endpoint using DeepSeek model"""
-     inputs = tokenizer(request.message, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_new_tokens=200)
-
-     reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-
-     return {"reply": reply}
+     """Chat endpoint using Gemini"""
+     response = model.generate_content(request.message)
+     return {"reply": response.text}
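The new /chat handler forwards the request body straight to Gemini; note that genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) raises a KeyError at startup if the key is in neither the environment nor .env. A minimal smoke test against a locally running instance, assuming the app is served with uvicorn app:app --port 8000 (the URL, port, and prompt are illustrative, not part of the commit):

    import requests

    # Assumes the server was started with: uvicorn app:app --port 8000
    resp = requests.post(
        "http://localhost:8000/chat",  # illustrative local URL
        json={"message": "Say hello in one sentence."},  # matches ChatRequest.message
    )
    resp.raise_for_status()
    print(resp.json()["reply"])  # the endpoint returns {"reply": response.text}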
 
 
 
 
 
app2.py ADDED
@@ -0,0 +1,57 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from fastapi.middleware.cors import CORSMiddleware
+ import torch
+ import os
+
+ # Ensure Hugging Face cache uses a writable path
+ os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"
+ os.environ["HF_HOME"] = "/app/.cache"
+
+ app = FastAPI()
+
+ # ✅ Allow all origins
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # allow all origins
+     allow_credentials=True,
+     allow_methods=["*"],  # allow all HTTP methods
+     allow_headers=["*"],  # allow all headers
+ )
+
+
+ class ChatRequest(BaseModel):
+     message: str
+
+ # Load DeepSeek model (small one for local use)
+ model_name = "deepseek-ai/deepseek-coder-1.3b-base"
+
+ # model_name = "deepseek-ai/deepseek-llm-7b-base"
+
+ #model_name="Qwen/Qwen2.5-1.5B-Instruct"
+ #model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+ print("Loading model... this may take a minute ⏳")
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto"
+ )
+ print("Model loaded ✅")
+
+ @app.get("/")
+ def root():
+     return {"status": "ok"}
+
+ @app.post("/chat")
+ def chat(request: ChatRequest):
+     """Chat endpoint using DeepSeek model"""
+     inputs = tokenizer(request.message, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=200)
+
+     reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+     return {"reply": reply}
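One quirk preserved in app2.py: model.generate returns the prompt tokens followed by the continuation, so tokenizer.decode(outputs[0], ...) echoes the user's message at the front of reply. A minimal sketch of decoding only the newly generated tokens, reusing the variable names from app2.py (standard transformers usage, not something this commit changes):

    # Inside chat(), after model.generate(): decode only the new tokens
    prompt_len = inputs["input_ids"].shape[1]  # number of tokens in the prompt
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)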